aifeifei798 committed
Commit 0a78742 · verified · 1 Parent(s): d9e02a2

Upload 2 files

Files changed (2):
  1. run.sh (+4 -4)
  2. train_zimage_lora.py (+36 -58)
run.sh CHANGED
@@ -7,12 +7,11 @@
 # --resume_from_checkpoint="feifei-zimage-lora/checkpoint-100"
 
 accelerate launch --mixed_precision="bf16" train_zimage_lora.py \
-  --pretrained_model_name_or_path="./Z-Image-Turbo" \
-  --train_data_dir="./fiefei_pic" \
+  --pretrained_model_name_or_path="../../../smodels/Z-Image-Turbo" \
+  --train_data_dir="../../../datasets/fiefei_pic" \
   --resolution=1024 \
   --train_batch_size=1 \
   --gradient_accumulation_steps=4 \
-  --max_train_steps=1000 \
   --learning_rate=1e-4 \
   --mixed_precision="bf16" \
   --output_dir="feifei-zimage-lora" \
@@ -21,4 +20,5 @@ accelerate launch --mixed_precision="bf16" train_zimage_lora.py \
   --gradient_checkpointing \
   --use_8bit_adam \
   --checkpointing_steps=100 \
-  --checkpoints_total_limit=3
+  --checkpoints_total_limit=3 \
+  --max_train_steps=200
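The run.sh edits point the model and dataset paths outside the repo and cut max_train_steps from 1000 to 200, now passed as the final flag so the trailing backslashes stay valid. A quick back-of-envelope for what 200 optimizer steps cover (the dataset size below is an assumption, not from the commit):

    # Rough budget for the new flags (dataset size is assumed):
    train_batch_size = 1
    gradient_accumulation_steps = 4
    max_train_steps = 200  # optimizer steps, i.e. global_step increments

    images_seen = max_train_steps * train_batch_size * gradient_accumulation_steps
    print(images_seen)  # 800 images; a hypothetical 100-image dataset would see ~8 passes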
train_zimage_lora.py CHANGED
@@ -4,7 +4,7 @@ import argparse
 import logging
 import math
 import os
-from safetensors.torch import save_file # Add this import
+from safetensors.torch import save_file
 import random
 import shutil
 import glob
@@ -58,7 +58,7 @@ check_min_version("0.24.0")
 logger = get_logger(__name__, log_level="INFO")
 
 def parse_args():
-    parser = argparse.ArgumentParser(description="Simple example of a training script.")
+    parser = argparse.ArgumentParser(description="Fixed Training script V3.")
     parser.add_argument("--pretrained_model_name_or_path", type=str, required=True)
     parser.add_argument("--revision", type=str, default=None)
     parser.add_argument("--variant", type=str, default=None)
@@ -242,7 +242,6 @@ def main():
     noise_scheduler = pipe.scheduler
     del pipe
 
-    # Analyze Params
     forward_signature = inspect.signature(transformer.forward)
     params = forward_signature.parameters
     param_names = list(params.keys())
@@ -252,11 +251,14 @@ def main():
     if "x" in param_names: input_arg = "x"
     elif "sample" in param_names: input_arg = "sample"
 
+    time_arg = "t"
+    if "timestep" in param_names: time_arg = "timestep"
+
     cond_arg = "encoder_hidden_states"
     if "cap_feats" in param_names: cond_arg = "cap_feats"
     elif "context" in param_names: cond_arg = "context"
 
-    logger.info(f"Mapping: Input='{input_arg}', Cond='{cond_arg}'")
+    logger.info(f"Mapping: Input='{input_arg}', Time='{time_arg}', Cond='{cond_arg}'")
 
     transformer.requires_grad_(False)
     vae.requires_grad_(False)
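This hunk replaces the hard-coded "t" keyword with a probed time_arg, alongside the existing input/conditioning probing. A minimal, self-contained sketch of the same idea on a toy module (the toy signature is illustrative, not Z-Image's actual forward):

    import inspect

    class ToyTransformer:
        # Illustrative signature only; the real Z-Image forward may differ.
        def forward(self, x, timestep, context, return_dict=False):
            return (x,)

    param_names = list(inspect.signature(ToyTransformer.forward).parameters.keys())

    input_arg = "x" if "x" in param_names else "sample"
    time_arg = "timestep" if "timestep" in param_names else "t"
    cond_arg = "encoder_hidden_states"
    if "cap_feats" in param_names:
        cond_arg = "cap_feats"
    elif "context" in param_names:
        cond_arg = "context"

    print(input_arg, time_arg, cond_arg)  # -> x timestep context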
@@ -360,25 +362,17 @@ def main():
         transformer, optimizer, train_dataloader, lr_scheduler
     )
 
-    # How many steps each epoch needs
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
 
-    # === [Core fix: robustly restore the training state] ===
-    # Start from 0 by default
     global_step = 0
     first_epoch = 0
     resume_step = 0
 
     if args.resume_from_checkpoint:
         logger.info(f"Resuming from checkpoint {args.resume_from_checkpoint}")
-
-        # [Core fix 1] Always call load_state, so this works whether or not accelerate launch is used
        accelerator.load_state(args.resume_from_checkpoint)
 
-        # [Core fix 2] More reliable global_step recovery logic
         checkpoint_path = args.resume_from_checkpoint.rstrip('/')
-
-        # Method 1: try to extract the step from the folder name
         folder_name = os.path.basename(checkpoint_path)
         if folder_name.startswith("checkpoint-") and "-" in folder_name:
             try:
@@ -386,7 +380,6 @@ def main():
             except:
                 pass
 
-        # Method 2: if extraction failed, fall back to saved_state.json (written by accelerate)
         if global_step == 0:
             saved_state_path = os.path.join(checkpoint_path, "saved_state.json")
             if os.path.exists(saved_state_path):
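The try-body that actually parses the step number sits on a context line outside the rendered hunk; given the startswith("checkpoint-") guard, it presumably does int(folder_name.split("-")[1]). A standalone sketch of that recovery chain (the helper name and exception handling are our assumption):

    import os

    def step_from_checkpoint_dir(path: str) -> int:
        # "feifei-zimage-lora/checkpoint-100/" -> 100; 0 if the name doesn't parse.
        folder_name = os.path.basename(path.rstrip("/"))
        if folder_name.startswith("checkpoint-"):
            try:
                return int(folder_name.split("-")[1])
            except (IndexError, ValueError):
                return 0
        return 0

    print(step_from_checkpoint_dir("feifei-zimage-lora/checkpoint-100/"))  # -> 100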
@@ -397,33 +390,16 @@ def main():
                     global_step = saved_state.get("global_step", 0)
             except:
                 pass
-
-        # Method 3: read from a progress file saved by the progress bar (if present)
-        if global_step == 0:
-            progress_path = os.path.join(checkpoint_path, "progress.json")
-            if os.path.exists(progress_path):
-                try:
-                    import json
-                    with open(progress_path, "r") as f:
-                        progress_data = json.load(f)
-                    global_step = progress_data.get("step", 0)
-                except:
-                    pass
-
-        logger.info(f"Restored global step: {global_step}")
-
-        # Work out which epoch and which step to resume from
+
         first_epoch = global_step // num_update_steps_per_epoch
         resume_step = global_step % num_update_steps_per_epoch
         logger.info(f"Resuming from epoch {first_epoch}, step {resume_step}")
 
+    # Training loop
     for epoch in range(first_epoch, args.num_train_epochs):
         transformer.train()
-
-        # An accumulator used to compute the average loss
         loss_accumulator = 0.0
 
-        # === [Dataloader skip logic] ===
         if args.resume_from_checkpoint and epoch == first_epoch:
             train_dataloader_skip = accelerator.skip_first_batches(train_dataloader, resume_step * args.train_batch_size)
         else:
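The resume arithmetic maps the recovered optimizer-step count back to an epoch index and an in-epoch offset; accelerator.skip_first_batches then fast-forwards the dataloader by a batch count. Note that each optimizer update consumes gradient_accumulation_steps dataloader batches. A worked example with assumed toy numbers:

    import math

    # Toy numbers, all assumed: 400 batches per epoch, accumulate over 4.
    num_batches_per_epoch = 400
    gradient_accumulation_steps = 4
    num_update_steps_per_epoch = math.ceil(num_batches_per_epoch / gradient_accumulation_steps)  # 100

    global_step = 250  # optimizer steps recovered from the checkpoint
    first_epoch = global_step // num_update_steps_per_epoch  # 2
    resume_step = global_step % num_update_steps_per_epoch   # 50 updates into epoch 2

    # Each optimizer update consumes gradient_accumulation_steps dataloader batches.
    batches_already_consumed = resume_step * gradient_accumulation_steps  # 200
    print(first_epoch, resume_step, batches_already_consumed)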
@@ -431,40 +407,56 @@ def main():
 
         for step, batch in enumerate(train_dataloader_skip):
             with accelerator.accumulate(transformer):
-                # --- Data prep and model forward pass (unchanged) ---
                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                 latents = latents * vae.config.scaling_factor
+
                 noise = torch.randn_like(latents)
                 bsz = latents.shape[0]
                 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long()
+
                 sigmas = timesteps.flatten() / noise_scheduler.config.num_train_timesteps
                 while sigmas.ndim < latents.ndim:
                     sigmas = sigmas.unsqueeze(-1)
+
                 noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
                 target = noise - latents
+
+                # Z-Image expects a 5D input
                 noisy_latents = noisy_latents.unsqueeze(2)
+
                 encoder_hidden_states = batch["encoder_hidden_states"].to(dtype=weight_dtype)
+
                 forward_kwargs = {
                     input_arg: noisy_latents,
-                    "t": timesteps,
+                    time_arg: timesteps,
                     cond_arg: encoder_hidden_states,
                     "return_dict": False
                 }
+
+                if "pooled_projections" in param_names:
+                    forward_kwargs["pooled_projections"] = torch.zeros(
+                        (bsz, 2048), device=latents.device, dtype=weight_dtype
+                    )
+
                 model_pred = transformer(**forward_kwargs)[0]
 
-                # --- Assorted shape fixes (unchanged) ---
                 if isinstance(model_pred, list): model_pred = model_pred[0]
-                if model_pred.ndim == 5: model_pred = model_pred.squeeze(2)
+
+                # === [Final core fix] Handle flipped dimensions ===
+
+                # 1. If the output is 5D, squeeze out the singleton dimension first
+                if model_pred.ndim == 5:
+                    model_pred = model_pred.squeeze(2)
+
+                # 2. If the output shape is flipped relative to the target (Channel <-> Batch), transpose it back
+                # Target: [1, 16, 128, 128] (Batch, Channel, H, W)
+                # Pred:   [16, 1, 128, 128] (Channel, Batch, H, W) -> needs fixing
                 if model_pred.shape != target.shape:
                     if model_pred.shape[0] == target.shape[1] and model_pred.shape[1] == target.shape[0]:
+                        # Transpose to swap Channel and Batch back
                         model_pred = model_pred.transpose(0, 1)
-                    elif model_pred.numel() == target.numel():
-                        model_pred = model_pred.reshape(target.shape)
 
-                # --- Loss computation and accumulation ---
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-                # [Core change 1]: add each micro-step's loss to the accumulator
                 loss_accumulator += loss.detach().item()
 
                 accelerator.backward(loss)
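The loop implements a rectified-flow style objective: with sigma = t/T, it trains on x_sigma = (1 - sigma) * x0 + sigma * eps and regresses the velocity target eps - x0, then repairs predictions whose batch and channel axes come back swapped. A self-contained shape check with dummy tensors (shapes taken from the comments in the hunk):

    import torch

    # Dummy latents shaped like the comments in the hunk: (Batch, Channel, H, W).
    latents = torch.randn(1, 16, 128, 128)
    noise = torch.randn_like(latents)

    sigmas = torch.rand(1)              # one sigma per sample, in [0, 1)
    while sigmas.ndim < latents.ndim:
        sigmas = sigmas.unsqueeze(-1)   # broadcastable: (1, 1, 1, 1)

    noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
    target = noise - latents            # velocity-style target: eps - x0

    # Simulate a prediction that comes back (Channel, Batch, H, W).
    model_pred = torch.randn(16, 1, 128, 128)
    if model_pred.shape != target.shape:
        if model_pred.shape[0] == target.shape[1] and model_pred.shape[1] == target.shape[0]:
            model_pred = model_pred.transpose(0, 1)

    assert model_pred.shape == target.shape == (1, 16, 128, 128)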
@@ -474,7 +466,6 @@ def main():
                 lr_scheduler.step()
                 optimizer.zero_grad()
 
-            # --- Operations after the model update ---
             if accelerator.sync_gradients:
                 global_step += 1
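Because accelerator.sync_gradients is True only on the micro-step where the optimizer actually steps, global_step counts optimizer updates, and the average printed in the next hunk divides the accumulated micro-step losses by gradient_accumulation_steps. The same pattern in isolation (loss values assumed):

    # The logging pattern, in isolation (loss values assumed):
    gradient_accumulation_steps = 4
    micro_losses = [0.52, 0.48, 0.50, 0.50]  # one MSE loss per micro-batch

    loss_accumulator = sum(micro_losses)
    avg_loss = loss_accumulator / gradient_accumulation_steps
    print(f"Loss: {avg_loss:.4f}")  # -> Loss: 0.5000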
 
@@ -483,45 +474,32 @@ def main():
                     save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                     accelerator.save_state(save_path)
 
-                # [Core change 2]: print the average loss only when the model actually updates
                 if accelerator.is_main_process:
-                    # Average loss over the past N micro-steps
                     avg_loss = loss_accumulator / args.gradient_accumulation_steps
-
-                    # Print the formatted result
                     print(f"Steps: {global_step}/{args.max_train_steps} | Loss: {avg_loss:.4f}")
-
-                    # Reset the accumulator for the next round
                     loss_accumulator = 0.0
 
             if global_step >= args.max_train_steps:
                 break
 
-        # === [Core change: two added lines] ===
-        # Once the step budget is reached, force-exit the outer epoch loop too, otherwise it idles
         if global_step >= args.max_train_steps:
             break
-        # ============================
 
-    # === The indentation here is 4 spaces ===
+    # Save the weights
     if accelerator.is_main_process:
         transformer = accelerator.unwrap_model(transformer)
         transformer = transformer.to(torch.float32)
         transformer_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(transformer))
 
-        # === [Core fix] Prefix the parameter names with "transformer." ===
-        # This step is critical; without it, load_lora_weights cannot recognize the keys
         new_state_dict = {}
         for k, v in transformer_lora_state_dict.items():
             new_state_dict[f"transformer.{k}"] = v
-        # ================================================
 
-        # Save the prefixed weights with safetensors
         save_path = os.path.join(args.output_dir, "pytorch_lora_weights.safetensors")
-        save_file(new_state_dict, save_path) # Note: this passes new_state_dict
+        save_file(new_state_dict, save_path)
         logger.info(f"Saved LoRA weights to {save_path}")
 
     accelerator.end_training()
 
 if __name__ == "__main__":
-    main()
+    main()
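The "transformer." key prefix matches the layout that diffusers' load_lora_weights expects for the transformer component. A hedged usage sketch for the saved file, assuming Z-Image-Turbo loads via DiffusionPipeline and the pipeline class exposes load_lora_weights (paths, prompt, and step count are placeholders):

    import torch
    from diffusers import DiffusionPipeline

    # Assumptions: the pipeline loads via DiffusionPipeline and exposes
    # load_lora_weights; adjust paths and settings to the real setup.
    pipe = DiffusionPipeline.from_pretrained(
        "../../../smodels/Z-Image-Turbo", torch_dtype=torch.bfloat16
    )
    pipe.load_lora_weights("feifei-zimage-lora")  # finds pytorch_lora_weights.safetensors
    pipe.to("cuda")

    image = pipe("a photo of feifei", num_inference_steps=8).images[0]
    image.save("lora_test.png")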
 