Upload train_zimage_lora.py

train_zimage_lora.py  CHANGED  (+42 -13)
@@ -371,22 +371,51 @@ def main():
 
     if args.resume_from_checkpoint:
         logger.info(f"Resuming from checkpoint {args.resume_from_checkpoint}")
-        # accelerator.load_state(args.resume_from_checkpoint)  # accelerate launch does this automatically, no need to call it here
 
-        #
-
-
-
+        # [Core fix 1] Always call load_state, so resuming works whether or not accelerate launch is used
+        accelerator.load_state(args.resume_from_checkpoint)
+
+        # [Core fix 2] More reliable global_step recovery logic
+        checkpoint_path = args.resume_from_checkpoint.rstrip('/')
+
+        # Method 1: try to extract the step from the checkpoint folder name
+        folder_name = os.path.basename(checkpoint_path)
+        if folder_name.startswith("checkpoint-") and "-" in folder_name:
+            try:
+                global_step = int(folder_name.split("-")[1])
+            except:
+                pass
+
+        # Method 2: if that fails, try reading saved_state.json (saved by accelerate)
+        if global_step == 0:
+            saved_state_path = os.path.join(checkpoint_path, "saved_state.json")
+            if os.path.exists(saved_state_path):
+                try:
+                    import json
+                    with open(saved_state_path, "r") as f:
+                        saved_state = json.load(f)
+                    global_step = saved_state.get("global_step", 0)
+                except:
+                    pass
+
+        # Method 3: read the file saved by the progress bar, if one exists
+        if global_step == 0:
+            progress_path = os.path.join(checkpoint_path, "progress.json")
+            if os.path.exists(progress_path):
+                try:
+                    import json
+                    with open(progress_path, "r") as f:
+                        progress_data = json.load(f)
+                    global_step = progress_data.get("step", 0)
+                except:
+                    pass
+
+        logger.info(f"Restored global step: {global_step}")
+
         # Work out which epoch and which step we should resume from
         first_epoch = global_step // num_update_steps_per_epoch
         resume_step = global_step % num_update_steps_per_epoch
-        logger.info(f"
-
-        logger.info("***** Running training (Final Fix) *****")
-
-        # Update the progress bar so it starts from the right place
-        # progress_bar = tqdm(range(global_step, args.max_train_steps), initial=global_step, desc="Steps", disable=not accelerator.is_local_main_process)
-        # progress_bar.update(global_step)
+        logger.info(f"Resuming from epoch {first_epoch}, step {resume_step}")
 
     for epoch in range(first_epoch, args.num_train_epochs):
         transformer.train()
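The hunk above restores global_step and derives first_epoch and resume_step, but the loop logic that skips the batches already processed in the resumed epoch is not visible in this diff. For reference only, a minimal sketch of one common pattern using Accelerate's skip_first_batches follows; accelerator, args, first_epoch and resume_step come from the script above, while train_dataloader and args.gradient_accumulation_steps are assumed names that may not match the actual script.

# Sketch only -- not part of this commit.
for epoch in range(first_epoch, args.num_train_epochs):
    if epoch == first_epoch and resume_step > 0:
        # skip_first_batches counts dataloader batches, so convert optimizer
        # update steps back into batches when gradient accumulation is used
        active_dataloader = accelerator.skip_first_batches(
            train_dataloader, resume_step * args.gradient_accumulation_steps
        )
    else:
        active_dataloader = train_dataloader
    for step, batch in enumerate(active_dataloader):
        ...  # unchanged training step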
@@ -474,7 +503,7 @@ def main():
             break
     # ============================
 
-    # === The indentation here is 4 spaces ===
+    # === The indentation here is 4 spaces ===
     if accelerator.is_main_process:
         transformer = accelerator.unwrap_model(transformer)
         transformer = transformer.to(torch.float32)