Upload 6 files
- .gitattributes +2 -0
- fiefei_pic/2024-11-07_09-52-37_9988.png +3 -0
- fiefei_pic/2024-11-07_09-52-37_9988.txt +1 -0
- fiefei_pic/2024-11-07_09-53-47_4210.png +3 -0
- fiefei_pic/2024-11-07_09-53-47_4210.txt +1 -0
- run.sh +21 -0
- train_zimage_lora.py +475 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+fiefei_pic/2024-11-07_09-52-37_9988.png filter=lfs diff=lfs merge=lfs -text
+fiefei_pic/2024-11-07_09-53-47_4210.png filter=lfs diff=lfs merge=lfs -text
fiefei_pic/2024-11-07_09-52-37_9988.png
ADDED
(binary image stored with Git LFS)

fiefei_pic/2024-11-07_09-52-37_9988.txt
ADDED
@@ -0,0 +1 @@
A close-up portrait of a jpop girl with long brown hair. She is wearing a white corset with a lace trim around the neckline. Her face is turned slightly to the left, and she is looking directly at the camera. The background is blurred, but it appears to be an indoor setting with a glass window and a plant.

fiefei_pic/2024-11-07_09-53-47_4210.png
ADDED
(binary image stored with Git LFS)

fiefei_pic/2024-11-07_09-53-47_4210.txt
ADDED
@@ -0,0 +1 @@
A close-up portrait of a jpop girl with long brown hair, wearing a red satin dress with a white lace trim. The background is blurred, but it appears to be an indoor setting. The woman is looking directly at the camera.
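Both uploads follow the sidecar-caption layout that train_zimage_lora.py expects: each image sits next to a .txt file with the same basename. A minimal sketch of that pairing step (mirroring the script's Phase 1 logic; the function name and data_dir argument are illustrative, not part of the upload):

import glob
import os

def collect_pairs(data_dir):
    # Pair every image with the caption stored in the same-named .txt file next to it.
    pairs = []
    for ext in ("*.png", "*.jpg", "*.jpeg", "*.webp"):
        for img_path in glob.glob(os.path.join(data_dir, "**", ext), recursive=True):
            txt_path = os.path.splitext(img_path)[0] + ".txt"
            if os.path.exists(txt_path):
                with open(txt_path, "r", encoding="utf-8") as f:
                    caption = f.read().strip()
                if caption:
                    pairs.append({"image_path": img_path, "caption": caption})
    return pairs

# e.g. collect_pairs("fiefei_pic") yields the two image/caption pairs added in this commit.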
run.sh
ADDED
@@ -0,0 +1,21 @@
#!/bin/bash

# Check whether bitsandbytes is installed and install it if missing (required for 8-bit Adam)
# pip install bitsandbytes > /dev/null 2>&1

accelerate launch --mixed_precision="bf16" train_zimage_lora.py \
    --pretrained_model_name_or_path="../../../smodels/Z-Image-Turbo" \
    --train_data_dir="../../../datasets/fiefei_pic" \
    --resolution=1024 \
    --train_batch_size=1 \
    --gradient_accumulation_steps=4 \
    --max_train_steps=1000 \
    --learning_rate=1e-4 \
    --mixed_precision="bf16" \
    --output_dir="feifei-zimage-lora" \
    --caption_column="text" \
    --rank=64 \
    --gradient_checkpointing \
    --use_8bit_adam \
    --checkpointing_steps=100 \
    --checkpoints_total_limit=3
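The install line in run.sh is left commented out. Since both --use_8bit_adam and the 4-bit text-encoder loading in train_zimage_lora.py depend on bitsandbytes, a small pre-flight check (an illustrative sketch, not part of this upload) can fail fast before accelerate launch:

# check_bnb.py -- illustrative pre-flight check, not one of the uploaded files.
# bitsandbytes backs both the 8-bit Adam optimizer (--use_8bit_adam) and the
# 4-bit BitsAndBytesConfig used for the text encoder in train_zimage_lora.py.
try:
    import bitsandbytes  # noqa: F401
except ImportError as exc:
    raise SystemExit("bitsandbytes is not installed; run `pip install bitsandbytes` first") from exc
print("bitsandbytes is available")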
train_zimage_lora.py
ADDED
@@ -0,0 +1,475 @@
#!/usr/bin/env python
# coding=utf-8
import argparse
import logging
import math
import os
from safetensors.torch import save_file  # used to export the LoRA weights at the end
import random
import shutil
import glob
import gc
import inspect
from contextlib import nullcontext
from pathlib import Path

import datasets
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset, Dataset, DatasetDict
from huggingface_hub import create_repo, upload_folder
from packaging import version
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict
from torchvision import transforms
from tqdm.auto import tqdm

from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig

import diffusers
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    Transformer2DModel,
    FlowMatchEulerDiscreteScheduler
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import cast_training_params, compute_snr
from diffusers.utils import (
    check_min_version,
    convert_state_dict_to_diffusers,
    is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available

if is_wandb_available():
    import wandb

check_min_version("0.24.0")

logger = get_logger(__name__, log_level="INFO")


def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument("--pretrained_model_name_or_path", type=str, required=True)
    parser.add_argument("--revision", type=str, default=None)
    parser.add_argument("--variant", type=str, default=None)
    parser.add_argument("--dataset_name", type=str, default=None)
    parser.add_argument("--dataset_config_name", type=str, default=None)
    parser.add_argument("--train_data_dir", type=str, default=None)
    parser.add_argument("--image_column", type=str, default="image")
    parser.add_argument("--caption_column", type=str, default="text")
    parser.add_argument("--validation_prompt", type=str, default=None)
    parser.add_argument("--num_validation_images", type=int, default=4)
    parser.add_argument("--validation_epochs", type=int, default=1)
    parser.add_argument("--max_train_samples", type=int, default=None)
    parser.add_argument("--output_dir", type=str, default="z-image-lora")
    parser.add_argument("--cache_dir", type=str, default=None)
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--resolution", type=int, default=1024)
    parser.add_argument("--center_crop", default=False, action="store_true")
    parser.add_argument("--random_flip", action="store_true")
    parser.add_argument("--train_batch_size", type=int, default=1)
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument("--max_train_steps", type=int, default=None)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--gradient_checkpointing", action="store_true")
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--scale_lr", action="store_true", default=False)
    parser.add_argument("--lr_scheduler", type=str, default="constant")
    parser.add_argument("--lr_warmup_steps", type=int, default=500)
    parser.add_argument("--snr_gamma", type=float, default=None)
    parser.add_argument("--use_8bit_adam", action="store_true")
    parser.add_argument("--allow_tf32", action="store_true")
    parser.add_argument("--dataloader_num_workers", type=int, default=0)
    parser.add_argument("--adam_beta1", type=float, default=0.9)
    parser.add_argument("--adam_beta2", type=float, default=0.999)
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2)
    parser.add_argument("--adam_epsilon", type=float, default=1e-08)
    parser.add_argument("--max_grad_norm", default=1.0, type=float)
    parser.add_argument("--push_to_hub", action="store_true")
    parser.add_argument("--hub_token", type=str, default=None)
    parser.add_argument("--prediction_type", type=str, default=None)
    parser.add_argument("--hub_model_id", type=str, default=None)
    parser.add_argument("--logging_dir", type=str, default="logs")
    parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"])
    parser.add_argument("--report_to", type=str, default="tensorboard")
    parser.add_argument("--local_rank", type=int, default=-1)
    parser.add_argument("--checkpointing_steps", type=int, default=500)
    parser.add_argument("--checkpoints_total_limit", type=int, default=None)
    parser.add_argument("--resume_from_checkpoint", type=str, default=None)
    parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true")
    parser.add_argument("--noise_offset", type=float, default=0)
    parser.add_argument("--rank", type=int, default=4)
    parser.add_argument("--image_interpolation_mode", type=str, default="lanczos")

    args = parser.parse_args()
    return args


def main():
    torch.cuda.empty_cache()
    gc.collect()

    args = parse_args()
    logging_dir = Path(args.output_dir, args.logging_dir)
    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )

    logging.basicConfig(level=logging.INFO)
    if args.seed is not None:
        set_seed(args.seed)

    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # ========================================================================
    # PHASE 1: Text Embedding Pre-computation
    # ========================================================================
    logger.info("PHASE 1: Text Embedding Pre-computation")

    data_entries = []
    image_files = []
    for ext in ["*.png", "*.jpg", "*.jpeg", "*.webp"]:
        image_files.extend(glob.glob(os.path.join(args.train_data_dir, "**", ext), recursive=True))
        image_files.extend(glob.glob(os.path.join(args.train_data_dir, "**", ext.upper()), recursive=True))

    for img_path in image_files:
        txt_path = os.path.splitext(img_path)[0] + ".txt"
        if os.path.exists(txt_path):
            with open(txt_path, "r", encoding="utf-8") as f:
                caption = f.read().strip()
            if caption:
                data_entries.append({"image_path": img_path, "caption": caption})

    if not data_entries:
        raise ValueError("No images found!")

    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=weight_dtype,
    )

    tokenizer = AutoTokenizer.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="tokenizer",
        revision=args.revision,
        trust_remote_code=True
    )
    if hasattr(tokenizer, "pad_token") and tokenizer.pad_token is None:
        if hasattr(tokenizer, "eos_token"):
            tokenizer.pad_token = tokenizer.eos_token
        else:
            tokenizer.add_special_tokens({'pad_token': '<pad>'})

    text_encoder = AutoModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="text_encoder",
        quantization_config=quantization_config,
        trust_remote_code=True,
        revision=args.revision,
    )

    precomputed_dataset = []
    batch_size = 1

    logger.info(f"Encoding {len(data_entries)} captions...")
    for i in tqdm(range(0, len(data_entries), batch_size), desc="Encoding"):
        batch = data_entries[i : i + batch_size]
        captions = [x["caption"] for x in batch]

        inputs = tokenizer(
            captions,
            max_length=512,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        input_ids = inputs.input_ids.to(accelerator.device)

        with torch.no_grad():
            encoder_outputs = text_encoder(input_ids, output_hidden_states=True, return_dict=True)
            embeddings = encoder_outputs.last_hidden_state.to(dtype=weight_dtype).cpu()

        for j, item in enumerate(batch):
            precomputed_dataset.append({
                "image": item["image_path"],
                "encoder_hidden_states": embeddings[j]
            })

    del text_encoder
    del tokenizer
    torch.cuda.empty_cache()
    gc.collect()

    # ========================================================================
    # PHASE 2: Model Loading
    # ========================================================================
    logger.info("PHASE 2: Loading Transformer...")

    pipe = DiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        trust_remote_code=True,
        revision=args.revision,
        variant=args.variant,
        text_encoder=None,
        tokenizer=None
    )

    transformer = pipe.transformer
    vae = pipe.vae
    noise_scheduler = pipe.scheduler
    del pipe

    # Inspect the transformer's forward signature so the right argument names are used below.
    forward_signature = inspect.signature(transformer.forward)
    params = forward_signature.parameters
    param_names = list(params.keys())
    logger.info(f"Detected Transformer forward params: {param_names}")

    input_arg = "hidden_states"
    if "x" in param_names: input_arg = "x"
    elif "sample" in param_names: input_arg = "sample"

    cond_arg = "encoder_hidden_states"
    if "cap_feats" in param_names: cond_arg = "cap_feats"
    elif "context" in param_names: cond_arg = "context"

    logger.info(f"Mapping: Input='{input_arg}', Cond='{cond_arg}'")

    transformer.requires_grad_(False)
    vae.requires_grad_(False)

    transformer.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)

    transformer_lora_config = LoraConfig(
        r=args.rank,
        lora_alpha=args.rank,
        init_lora_weights="gaussian",
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    )
    transformer.add_adapter(transformer_lora_config)

    if args.mixed_precision == "fp16":
        cast_training_params(transformer, dtype=torch.float32)

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            if hasattr(transformer, "enable_xformers_memory_efficient_attention"):
                transformer.enable_xformers_memory_efficient_attention()

    if args.gradient_checkpointing:
        transformer.enable_gradient_checkpointing()

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    if args.use_8bit_adam:
        import bitsandbytes as bnb
        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    # Materialize the trainable params as a list so they can be iterated again for gradient clipping below.
    lora_layers = list(filter(lambda p: p.requires_grad, transformer.parameters()))
    optimizer = optimizer_cls(
        lora_layers,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    dataset = Dataset.from_list(precomputed_dataset).cast_column("image", datasets.Image())
    dataset = DatasetDict({"train": dataset})

    interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
    train_transforms = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=interpolation),
            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
            transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    def preprocess_train(examples):
        images = [image.convert("RGB") for image in examples["image"]]
        examples["pixel_values"] = [train_transforms(image) for image in images]
        return examples

    with accelerator.main_process_first():
        if args.max_train_samples is not None:
            dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
        train_dataset = dataset["train"].with_transform(preprocess_train)

    def collate_fn(examples):
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
        encoder_hidden_states = torch.stack([torch.tensor(example["encoder_hidden_states"]) for example in examples])
        return {"pixel_values": pixel_values, "encoder_hidden_states": encoder_hidden_states}

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )

    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
    if args.max_train_steps is None:
        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
        num_training_steps_for_scheduler = args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
    else:
        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps_for_scheduler,
        num_training_steps=num_training_steps_for_scheduler,
    )

    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        transformer, optimizer, train_dataloader, lr_scheduler
    )

    logger.info("***** Running training (Final Fix) *****")
    global_step = 0
    first_epoch = 0

    progress_bar = tqdm(range(0, args.max_train_steps), initial=0, desc="Steps", disable=not accelerator.is_local_main_process)

    for epoch in range(first_epoch, args.num_train_epochs):
        transformer.train()
        train_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(transformer):
                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                latents = latents * vae.config.scaling_factor
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]

                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long()

                sigmas = timesteps.flatten() / noise_scheduler.config.num_train_timesteps
                while sigmas.ndim < latents.ndim:
                    sigmas = sigmas.unsqueeze(-1)

                noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
                target = noise - latents

                # <--- Core fix: unsqueeze to (Batch, Channels, Frames=1, H, W) --->
                noisy_latents = noisy_latents.unsqueeze(2)

                encoder_hidden_states = batch["encoder_hidden_states"].to(dtype=weight_dtype)

                forward_kwargs = {
                    input_arg: noisy_latents,
                    "t": timesteps,  # ensure the timestep is passed
                    cond_arg: encoder_hidden_states,
                    "return_dict": False
                }

                model_pred = transformer(**forward_kwargs)[0]

                # 1. Handle case where output is a list (fixes AttributeError: 'list' object has no attribute 'ndim')
                if isinstance(model_pred, list):
                    model_pred = model_pred[0]

                # 2. Handle 5D output (Video/Motion modules)
                if model_pred.ndim == 5:
                    model_pred = model_pred.squeeze(2)

                # 3. Fix shape mismatch: Target [B, C, H, W] vs Pred [C, B, H, W] or similar
                # The error showed Target [1, 16, 128, 128] vs Pred [16, 1, 128, 128]
                if model_pred.shape != target.shape:
                    # If dimensions are just swapped (e.g., [16, 1] vs [1, 16]), transpose them
                    if model_pred.shape[0] == target.shape[1] and model_pred.shape[1] == target.shape[0]:
                        model_pred = model_pred.transpose(0, 1)
                    # If shapes still don't match but have the same number of elements, force reshape
                    elif model_pred.numel() == target.numel():
                        model_pred = model_pred.reshape(target.shape)

                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
                train_loss += avg_loss.item() / args.gradient_accumulation_steps

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(lora_layers, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

        # === Core change: once the step budget is reached, also break out of the
        # outer epoch loop, otherwise it keeps spinning through empty epochs ===
        if global_step >= args.max_train_steps:
            break
        # ============================

    # === Indented 4 spaces: runs once, after the epoch loop ===
    if accelerator.is_main_process:
        transformer = accelerator.unwrap_model(transformer)
        transformer = transformer.to(torch.float32)
        transformer_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(transformer))

        # === Core fix: prefix the parameter names with "transformer." ===
        # This step is critical; without it, load_lora_weights cannot recognize the weights.
        new_state_dict = {}
        for k, v in transformer_lora_state_dict.items():
            new_state_dict[f"transformer.{k}"] = v
        # ================================================

        # Save the prefixed weights with safetensors
        save_path = os.path.join(args.output_dir, "pytorch_lora_weights.safetensors")
        save_file(new_state_dict, save_path)  # note: the prefixed new_state_dict is passed here
        logger.info(f"Saved LoRA weights to {save_path}")

    accelerator.end_training()


if __name__ == "__main__":
    main()
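The "transformer." prefix added before saving targets the diffusers load_lora_weights convention. Below is a minimal inference sketch under that assumption; whether the custom Z-Image-Turbo pipeline loaded with trust_remote_code actually exposes load_lora_weights, and which sampler settings suit it, should be verified against the model card.

import torch
from diffusers import DiffusionPipeline

# Base pipeline path mirrors run.sh; adjust to your local checkout or Hub id.
pipe = DiffusionPipeline.from_pretrained(
    "../../../smodels/Z-Image-Turbo",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).to("cuda")

# Load the LoRA exported by train_zimage_lora.py (pytorch_lora_weights.safetensors in the output dir).
pipe.load_lora_weights("feifei-zimage-lora")

image = pipe(
    prompt="A close-up portrait of a jpop girl with long brown hair",
    num_inference_steps=8,  # assumption: a low step count, as is typical for Turbo-style models
).images[0]
image.save("lora_sample.png")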