Spaces:
Running
on
Zero
Running
on
Zero
Sync from GitHub repo
Browse filesThis Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there
src/f5_tts/model/trainer.py
CHANGED
|
@@ -51,7 +51,7 @@ class Trainer:
|
|
| 51 |
mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
|
| 52 |
is_local_vocoder: bool = False, # use local path vocoder
|
| 53 |
local_vocoder_path: str = "", # local vocoder path
|
| 54 |
-
|
| 55 |
):
|
| 56 |
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
| 57 |
|
|
@@ -73,8 +73,8 @@ class Trainer:
|
|
| 73 |
else:
|
| 74 |
init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
|
| 75 |
|
| 76 |
-
if not
|
| 77 |
-
|
| 78 |
"epochs": epochs,
|
| 79 |
"learning_rate": learning_rate,
|
| 80 |
"num_warmup_updates": num_warmup_updates,
|
|
@@ -85,11 +85,11 @@ class Trainer:
|
|
| 85 |
"max_grad_norm": max_grad_norm,
|
| 86 |
"noise_scheduler": noise_scheduler,
|
| 87 |
}
|
| 88 |
-
|
| 89 |
self.accelerator.init_trackers(
|
| 90 |
project_name=wandb_project,
|
| 91 |
init_kwargs=init_kwargs,
|
| 92 |
-
config=
|
| 93 |
)
|
| 94 |
|
| 95 |
elif self.logger == "tensorboard":
|
|
|
|
| 51 |
mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
|
| 52 |
is_local_vocoder: bool = False, # use local path vocoder
|
| 53 |
local_vocoder_path: str = "", # local vocoder path
|
| 54 |
+
model_cfg_dict: dict = dict(), # training config
|
| 55 |
):
|
| 56 |
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
| 57 |
|
|
|
|
| 73 |
else:
|
| 74 |
init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
|
| 75 |
|
| 76 |
+
if not model_cfg_dict:
|
| 77 |
+
model_cfg_dict = {
|
| 78 |
"epochs": epochs,
|
| 79 |
"learning_rate": learning_rate,
|
| 80 |
"num_warmup_updates": num_warmup_updates,
|
|
|
|
| 85 |
"max_grad_norm": max_grad_norm,
|
| 86 |
"noise_scheduler": noise_scheduler,
|
| 87 |
}
|
| 88 |
+
model_cfg_dict["gpus"] = self.accelerator.num_processes
|
| 89 |
self.accelerator.init_trackers(
|
| 90 |
project_name=wandb_project,
|
| 91 |
init_kwargs=init_kwargs,
|
| 92 |
+
config=model_cfg_dict,
|
| 93 |
)
|
| 94 |
|
| 95 |
elif self.logger == "tensorboard":
|