aifeifei798 committed
Commit 7aec2ba · verified · 1 Parent(s): dde1da4

Upload train_zimage_lora.py

Files changed (1)
  1. train_zimage_lora.py +42 -13
train_zimage_lora.py CHANGED
@@ -371,22 +371,51 @@ def main():
 
     if args.resume_from_checkpoint:
         logger.info(f"Resuming from checkpoint {args.resume_from_checkpoint}")
-        # accelerator.load_state(args.resume_from_checkpoint)  # accelerate launch does this automatically, so no need to call it here
 
-        # Extract the step we last trained to from the folder name, e.g. "checkpoint-500" -> 500
-        path = os.path.basename(args.resume_from_checkpoint)
-        global_step = int(path.split("-")[1])
-
+        # [Core fix 1] Always call load_state, so resuming works whether or not accelerate launch is used
+        accelerator.load_state(args.resume_from_checkpoint)
+
+        # [Core fix 2] More robust global_step recovery logic
+        checkpoint_path = args.resume_from_checkpoint.rstrip('/')
+
+        # Method 1: try to extract the step from the folder name
+        folder_name = os.path.basename(checkpoint_path)
+        if folder_name.startswith("checkpoint-") and "-" in folder_name:
+            try:
+                global_step = int(folder_name.split("-")[1])
+            except (ValueError, IndexError):
+                pass
+
+        # Method 2: if that fails, try reading saved_state.json (saved by accelerate)
+        if global_step == 0:
+            saved_state_path = os.path.join(checkpoint_path, "saved_state.json")
+            if os.path.exists(saved_state_path):
+                try:
+                    import json
+                    with open(saved_state_path, "r") as f:
+                        saved_state = json.load(f)
+                    global_step = saved_state.get("global_step", 0)
+                except (OSError, ValueError):
+                    pass
+
+        # Method 3: read from the saved progress file, if one exists
+        if global_step == 0:
+            progress_path = os.path.join(checkpoint_path, "progress.json")
+            if os.path.exists(progress_path):
+                try:
+                    import json
+                    with open(progress_path, "r") as f:
+                        progress_data = json.load(f)
+                    global_step = progress_data.get("step", 0)
+                except (OSError, ValueError):
+                    pass
+
+        logger.info(f"Restored global step: {global_step}")
+
     # Work out which epoch and step we should resume from
     first_epoch = global_step // num_update_steps_per_epoch
     resume_step = global_step % num_update_steps_per_epoch
-    logger.info(f"Global step restored to {global_step}, resuming from epoch {first_epoch}, step {resume_step}")
-
-    logger.info("***** Running training (Final Fix) *****")
-
-    # Update the progress bar so it starts from the right place
-    # progress_bar = tqdm(range(global_step, args.max_train_steps), initial=global_step, desc="Steps", disable=not accelerator.is_local_main_process)
-    # progress_bar.update(global_step)
 
     for epoch in range(first_epoch, args.num_train_epochs):
         transformer.train()
@@ -474,7 +503,7 @@ def main():
             break
     # ============================
 
-    # === The indentation here is 4 spaces ===
+    # === The indentation here is 4 spaces ===
     if accelerator.is_main_process:
         transformer = accelerator.unwrap_model(transformer)
         transformer = transformer.to(torch.float32)
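Note that method 3 of the new resume logic reads progress.json out of the checkpoint folder, but nothing in this hunk writes that file. A minimal sketch of the save side that would satisfy it, assuming accelerate's save_state API and the checkpoint-<step> directory layout the script already uses (the save_checkpoint helper and output_dir parameter are illustrative, not part of the script):

import json
import os

from accelerate import Accelerator


def save_checkpoint(accelerator: Accelerator, output_dir: str, global_step: int) -> None:
    # Save model/optimizer/RNG state through accelerate, then write a small
    # progress.json next to it so the resume code can recover global_step
    # even if the folder is later renamed. (Illustrative sketch only.)
    ckpt_dir = os.path.join(output_dir, f"checkpoint-{global_step}")
    accelerator.save_state(ckpt_dir)
    if accelerator.is_main_process:
        with open(os.path.join(ckpt_dir, "progress.json"), "w") as f:
            json.dump({"step": global_step}, f)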
 
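For reference, the patch's three-way fallback can also be read as one small pure function. The sketch below mirrors the diff's logic (the file names and JSON keys come from the patch, while the restore_global_step helper itself is hypothetical) and is easy to unit-test in isolation:

import json
import os


def restore_global_step(checkpoint_path: str) -> int:
    # Method 1: parse the step out of a "checkpoint-<step>" folder name.
    step = 0
    folder = os.path.basename(checkpoint_path.rstrip("/"))
    if folder.startswith("checkpoint-"):
        try:
            step = int(folder.split("-")[1])
        except (ValueError, IndexError):
            pass
    # Methods 2 and 3: fall back to the JSON files stored inside the checkpoint.
    for name, key in (("saved_state.json", "global_step"), ("progress.json", "step")):
        if step:
            break
        path = os.path.join(checkpoint_path, name)
        if os.path.exists(path):
            try:
                with open(path, "r") as f:
                    step = int(json.load(f).get(key, 0))
            except (OSError, ValueError, AttributeError):
                pass
    return step


# e.g. restore_global_step("output/checkpoint-500") == 500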