Upload 2 files
- run.sh (+4 -4)
- train_zimage_lora.py (+36 -58)
run.sh
CHANGED
@@ -7,12 +7,11 @@
 # --resume_from_checkpoint="feifei-zimage-lora/checkpoint-100"
 
 accelerate launch --mixed_precision="bf16" train_zimage_lora.py \
-  --pretrained_model_name_or_path="
-  --train_data_dir="
+  --pretrained_model_name_or_path="../../../smodels/Z-Image-Turbo" \
+  --train_data_dir="../../../datasets/fiefei_pic" \
   --resolution=1024 \
   --train_batch_size=1 \
   --gradient_accumulation_steps=4 \
-  --max_train_steps=1000 \
   --learning_rate=1e-4 \
   --mixed_precision="bf16" \
   --output_dir="feifei-zimage-lora" \

@@ -21,4 +20,5 @@ accelerate launch --mixed_precision="bf16" train_zimage_lora.py \
   --gradient_checkpointing \
   --use_8bit_adam \
   --checkpointing_steps=100 \
-  --checkpoints_total_limit=3
+  --checkpoints_total_limit=3 \
+  --max_train_steps=200
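The --resume_from_checkpoint line stays commented out; when resuming, it has to point at a concrete checkpoint-N directory under the output dir. A minimal sketch of picking the newest one automatically, relying only on the checkpoint-{global_step} naming the training script uses (latest_checkpoint is a hypothetical helper, not part of either file):

import glob
import os

def latest_checkpoint(output_dir="feifei-zimage-lora"):
    # Hypothetical helper: resolve the newest accelerate checkpoint dir.
    candidates = glob.glob(os.path.join(output_dir, "checkpoint-*"))
    if not candidates:
        return None  # nothing to resume from
    return max(candidates, key=lambda d: int(d.rsplit("-", 1)[-1]))

print(latest_checkpoint())  # e.g. feifei-zimage-lora/checkpoint-100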
train_zimage_lora.py
CHANGED
@@ -4,7 +4,7 @@ import argparse
 import logging
 import math
 import os
-from safetensors.torch import save_file
+from safetensors.torch import save_file
 import random
 import shutil
 import glob
@@ -58,7 +58,7 @@ check_min_version("0.24.0")
 logger = get_logger(__name__, log_level="INFO")
 
 def parse_args():
-    parser = argparse.ArgumentParser(description="
+    parser = argparse.ArgumentParser(description="Fixed Training script V3.")
     parser.add_argument("--pretrained_model_name_or_path", type=str, required=True)
     parser.add_argument("--revision", type=str, default=None)
    parser.add_argument("--variant", type=str, default=None)
@@ -242,7 +242,6 @@ def main():
     noise_scheduler = pipe.scheduler
     del pipe
 
-    # Analyze Params
     forward_signature = inspect.signature(transformer.forward)
     params = forward_signature.parameters
     param_names = list(params.keys())
@@ -252,11 +251,14 @@
     if "x" in param_names: input_arg = "x"
     elif "sample" in param_names: input_arg = "sample"
 
+    time_arg = "t"
+    if "timestep" in param_names: time_arg = "timestep"
+
     cond_arg = "encoder_hidden_states"
     if "cap_feats" in param_names: cond_arg = "cap_feats"
     elif "context" in param_names: cond_arg = "context"
 
-    logger.info(f"Mapping: Input='{input_arg}', Cond='{cond_arg}'")
+    logger.info(f"Mapping: Input='{input_arg}', Time='{time_arg}', Cond='{cond_arg}'")
 
     transformer.requires_grad_(False)
     vae.requires_grad_(False)
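To see what this introspection buys, here is the same mapping applied to a toy forward signature; fake_forward below is purely illustrative and stands in for the real transformer.forward:

import inspect

def fake_forward(x, t, cap_feats, return_dict=True):
    pass  # toy stand-in for the transformer's forward

param_names = list(inspect.signature(fake_forward).parameters.keys())

input_arg = "sample"
if "x" in param_names: input_arg = "x"

time_arg = "t"
if "timestep" in param_names: time_arg = "timestep"

cond_arg = "encoder_hidden_states"
if "cap_feats" in param_names: cond_arg = "cap_feats"

assert (input_arg, time_arg, cond_arg) == ("x", "t", "cap_feats")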
@@ -360,25 +362,17 @@
         transformer, optimizer, train_dataloader, lr_scheduler
     )
 
-    # How many update steps per epoch
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
 
-    # === Core fix: smart recovery of the training state ===
-    # Start from 0 by default
     global_step = 0
     first_epoch = 0
     resume_step = 0
 
     if args.resume_from_checkpoint:
         logger.info(f"Resuming from checkpoint {args.resume_from_checkpoint}")
-
-        # Core fix 1: always call load_state, so this works with or without accelerate launch
         accelerator.load_state(args.resume_from_checkpoint)
 
-        # Core fix 2: more reliable global_step recovery logic
         checkpoint_path = args.resume_from_checkpoint.rstrip('/')
-
-        # Method 1: try to extract it from the folder name
         folder_name = os.path.basename(checkpoint_path)
         if folder_name.startswith("checkpoint-") and "-" in folder_name:
             try:
@@ -386,7 +380,6 @@
             except:
                 pass
 
-        # Method 2: if that fails, try reading saved_state.json (written by accelerate)
         if global_step == 0:
             saved_state_path = os.path.join(checkpoint_path, "saved_state.json")
             if os.path.exists(saved_state_path):
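The body of the try block (original line 385) falls between the hunks, but the surrounding logic implies an integer parse of the folder suffix. A standalone sketch of that recovery path, under the assumption that the missing line parses the suffix of checkpoint-{global_step}:

import os

checkpoint_path = "feifei-zimage-lora/checkpoint-100".rstrip("/")
folder_name = os.path.basename(checkpoint_path)
global_step = 0
if folder_name.startswith("checkpoint-") and "-" in folder_name:
    try:
        # Assumed reconstruction of the elided parse
        global_step = int(folder_name.split("-")[-1])
    except ValueError:
        pass
assert global_step == 100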
@@ -397,33 +390,16 @@
                     global_step = saved_state.get("global_step", 0)
                 except:
                     pass
-
-        # Method 3: read from the saved progress file (if any)
-        if global_step == 0:
-            progress_path = os.path.join(checkpoint_path, "progress.json")
-            if os.path.exists(progress_path):
-                try:
-                    import json
-                    with open(progress_path, "r") as f:
-                        progress_data = json.load(f)
-                    global_step = progress_data.get("step", 0)
-                except:
-                    pass
-
-        logger.info(f"Recovered global step: {global_step}")
-
-        # Work out which epoch and step we should start from
+
         first_epoch = global_step // num_update_steps_per_epoch
         resume_step = global_step % num_update_steps_per_epoch
         logger.info(f"Resuming from epoch {first_epoch}, step {resume_step}")
 
+    # Training loop
     for epoch in range(first_epoch, args.num_train_epochs):
         transformer.train()
-
-        # Define an accumulator for computing the average loss
         loss_accumulator = 0.0
 
-        # === Dataloader skip logic ===
         if args.resume_from_checkpoint and epoch == first_epoch:
             train_dataloader_skip = accelerator.skip_first_batches(train_dataloader, resume_step * args.train_batch_size)
         else:
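The resume arithmetic is worth a worked example. With hypothetical numbers (50 batches per epoch, plus the --gradient_accumulation_steps=4 from run.sh), a checkpoint at global step 100 resumes as follows:

import math

# Hypothetical: 50 batches per epoch; accumulation of 4 taken from run.sh.
num_update_steps_per_epoch = math.ceil(50 / 4)  # 13 optimizer steps per epoch
global_step = 100

first_epoch = global_step // num_update_steps_per_epoch  # 7
resume_step = global_step % num_update_steps_per_epoch   # 9
# skip_first_batches above then skips resume_step * train_batch_size batches.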
@@ -431,40 +407,56 @@
 
         for step, batch in enumerate(train_dataloader_skip):
             with accelerator.accumulate(transformer):
-                # --- Data prep and model forward pass (your code, unchanged) ---
                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                 latents = latents * vae.config.scaling_factor
+
                 noise = torch.randn_like(latents)
                 bsz = latents.shape[0]
                 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long()
+
                 sigmas = timesteps.flatten() / noise_scheduler.config.num_train_timesteps
                 while sigmas.ndim < latents.ndim:
                     sigmas = sigmas.unsqueeze(-1)
+
                 noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
                 target = noise - latents
+
+                # Z-Image expects 5D input
                 noisy_latents = noisy_latents.unsqueeze(2)
+
                 encoder_hidden_states = batch["encoder_hidden_states"].to(dtype=weight_dtype)
+
                 forward_kwargs = {
                     input_arg: noisy_latents,
-
+                    time_arg: timesteps,
                     cond_arg: encoder_hidden_states,
                     "return_dict": False
                 }
+
+                if "pooled_projections" in param_names:
+                    forward_kwargs["pooled_projections"] = torch.zeros(
+                        (bsz, 2048), device=latents.device, dtype=weight_dtype
+                    )
+
                 model_pred = transformer(**forward_kwargs)[0]
 
-                # --- Assorted shape fixes (your code, unchanged) ---
                 if isinstance(model_pred, list): model_pred = model_pred[0]
-
+
+                # === Final core fix: handle flipped dimensions ===
+
+                # 1. If the output is 5D, squeeze out the singleton dim first
+                if model_pred.ndim == 5:
+                    model_pred = model_pred.squeeze(2)
+
+                # 2. If the prediction's shape is flipped vs. the target (Channel <-> Batch), transpose back
+                # Target: [1, 16, 128, 128] (Batch, Channel, H, W)
+                # Pred:   [16, 1, 128, 128] (Channel, Batch, H, W) -> needs fixing
                 if model_pred.shape != target.shape:
                     if model_pred.shape[0] == target.shape[1] and model_pred.shape[1] == target.shape[0]:
+                        # Transpose to swap Channel and Batch back
                         model_pred = model_pred.transpose(0, 1)
-                    elif model_pred.numel() == target.numel():
-                        model_pred = model_pred.reshape(target.shape)
 
-                # --- Loss computation and accumulation ---
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-                # Core change 1: add each micro-step's loss to the accumulator
                 loss_accumulator += loss.detach().item()
 
                 accelerator.backward(loss)
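A toy-tensor walk-through of the interpolation and the shape guard above; the latent sizes are made up, but the arithmetic matches the hunk:

import torch

# Made-up latent sizes; the formulas are the ones from the hunk above.
latents = torch.randn(1, 16, 128, 128)
noise = torch.randn_like(latents)
timesteps = torch.randint(0, 1000, (1,))

sigmas = timesteps.flatten() / 1000
while sigmas.ndim < latents.ndim:
    sigmas = sigmas.unsqueeze(-1)

noisy_latents = (1.0 - sigmas) * latents + sigmas * noise  # linear interpolation
target = noise - latents                                   # velocity target

# Simulate a (Channel, Batch, H, W) prediction and apply the same guard.
model_pred = torch.randn(16, 1, 128, 128)
if model_pred.shape != target.shape:
    if model_pred.shape[0] == target.shape[1] and model_pred.shape[1] == target.shape[0]:
        model_pred = model_pred.transpose(0, 1)
assert model_pred.shape == target.shape == (1, 16, 128, 128)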
@@ -474,7 +466,6 @@
                 lr_scheduler.step()
                 optimizer.zero_grad()
 
-                # --- Operations after a model update ---
                 if accelerator.sync_gradients:
                     global_step += 1
 
@@ -483,45 +474,32 @@
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
 
-                    # Core change 2: only print the average loss when the model actually updates
                     if accelerator.is_main_process:
-                        # Average loss over the past N steps
                         avg_loss = loss_accumulator / args.gradient_accumulation_steps
-
-                        # Print the formatted result
                         print(f"Steps: {global_step}/{args.max_train_steps} | Loss: {avg_loss:.4f}")
-
-                        # Reset the accumulator for the next round
                         loss_accumulator = 0.0
 
             if global_step >= args.max_train_steps:
                 break
 
-        # === Core change: add these two lines ===
-        # Once the step budget is hit, force-break the outer epoch loop too, otherwise it spins idle
         if global_step >= args.max_train_steps:
             break
-        # ============================
 
-    #
+    # Save the weights
     if accelerator.is_main_process:
         transformer = accelerator.unwrap_model(transformer)
         transformer = transformer.to(torch.float32)
         transformer_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(transformer))
 
-        # === Core fix: prepend "transformer." to the parameter names ===
-        # This step is critical; without it, load_lora_weights cannot recognize the weights
         new_state_dict = {}
         for k, v in transformer_lora_state_dict.items():
             new_state_dict[f"transformer.{k}"] = v
-        # ================================================
 
-        # Save the prefixed weights with safetensors
         save_path = os.path.join(args.output_dir, "pytorch_lora_weights.safetensors")
-        save_file(new_state_dict, save_path)
+        save_file(new_state_dict, save_path)
         logger.info(f"Saved LoRA weights to {save_path}")
 
     accelerator.end_training()
 
 if __name__ == "__main__":
-    main()
+    main()
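A quick way to sanity-check the exported file is to load it back for inference. This assumes the Z-Image-Turbo checkpoint loads through diffusers' generic DiffusionPipeline and exposes the standard LoRA loader mixin; the "transformer." prefix added above is what load_lora_weights keys on. The paths are the ones from run.sh:

import torch
from diffusers import DiffusionPipeline

# Assumption: the pipeline class behind Z-Image-Turbo supports the
# standard diffusers LoRA loading path.
pipe = DiffusionPipeline.from_pretrained(
    "../../../smodels/Z-Image-Turbo", torch_dtype=torch.bfloat16
)
pipe.load_lora_weights("feifei-zimage-lora")  # picks up pytorch_lora_weights.safetensors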