#!/usr/bin/env python # coding=utf-8 import argparse import logging import math import os from safetensors.torch import save_file import random import shutil import glob import gc import inspect from contextlib import nullcontext from pathlib import Path import datasets import numpy as np import torch import torch.nn.functional as F import torch.utils.checkpoint import transformers from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed from datasets import load_dataset, Dataset, DatasetDict from huggingface_hub import create_repo, upload_folder from packaging import version from peft import LoraConfig from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict from torchvision import transforms from tqdm.auto import tqdm from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig import diffusers from diffusers import ( AutoencoderKL, DDPMScheduler, DiffusionPipeline, Transformer2DModel, FlowMatchEulerDiscreteScheduler ) from diffusers.optimization import get_scheduler from diffusers.training_utils import cast_training_params, compute_snr from diffusers.utils import ( check_min_version, convert_state_dict_to_diffusers, is_wandb_available, ) from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.import_utils import is_xformers_available if is_wandb_available(): import wandb check_min_version("0.24.0") logger = get_logger(__name__, log_level="INFO") def parse_args(): parser = argparse.ArgumentParser(description="Fixed Training script V3.") parser.add_argument("--pretrained_model_name_or_path", type=str, required=True) parser.add_argument("--revision", type=str, default=None) parser.add_argument("--variant", type=str, default=None) parser.add_argument("--dataset_name", type=str, default=None) parser.add_argument("--dataset_config_name", type=str, default=None) parser.add_argument("--train_data_dir", type=str, default=None) parser.add_argument("--image_column", type=str, default="image") parser.add_argument("--caption_column", type=str, default="text") parser.add_argument("--validation_prompt", type=str, default=None) parser.add_argument("--num_validation_images", type=int, default=4) parser.add_argument("--validation_epochs", type=int, default=1) parser.add_argument("--max_train_samples", type=int, default=None) parser.add_argument("--output_dir", type=str, default="z-image-lora") parser.add_argument("--cache_dir", type=str, default=None) parser.add_argument("--seed", type=int, default=None) parser.add_argument("--resolution", type=int, default=1024) parser.add_argument("--center_crop", default=False, action="store_true") parser.add_argument("--random_flip", action="store_true") parser.add_argument("--train_batch_size", type=int, default=1) parser.add_argument("--num_train_epochs", type=int, default=100) parser.add_argument("--max_train_steps", type=int, default=None) parser.add_argument("--gradient_accumulation_steps", type=int, default=1) parser.add_argument("--gradient_checkpointing", action="store_true") parser.add_argument("--learning_rate", type=float, default=1e-4) parser.add_argument("--scale_lr", action="store_true", default=False) parser.add_argument("--lr_scheduler", type=str, default="constant") parser.add_argument("--lr_warmup_steps", type=int, default=500) parser.add_argument("--snr_gamma", type=float, default=None) parser.add_argument("--use_8bit_adam", action="store_true") parser.add_argument("--allow_tf32", action="store_true") parser.add_argument("--dataloader_num_workers", type=int, default=0) parser.add_argument("--adam_beta1", type=float, default=0.9) parser.add_argument("--adam_beta2", type=float, default=0.999) parser.add_argument("--adam_weight_decay", type=float, default=1e-2) parser.add_argument("--adam_epsilon", type=float, default=1e-08) parser.add_argument("--max_grad_norm", default=1.0, type=float) parser.add_argument("--push_to_hub", action="store_true") parser.add_argument("--hub_token", type=str, default=None) parser.add_argument("--prediction_type", type=str, default=None) parser.add_argument("--hub_model_id", type=str, default=None) parser.add_argument("--logging_dir", type=str, default="logs") parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"]) parser.add_argument("--report_to", type=str, default="tensorboard") parser.add_argument("--local_rank", type=int, default=-1) parser.add_argument("--checkpointing_steps", type=int, default=500) parser.add_argument("--checkpoints_total_limit", type=int, default=None) parser.add_argument("--resume_from_checkpoint", type=str, default=None) parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true") parser.add_argument("--noise_offset", type=float, default=0) parser.add_argument("--rank", type=int, default=4) parser.add_argument("--image_interpolation_mode", type=str, default="lanczos") args = parser.parse_args() return args def main(): torch.cuda.empty_cache() gc.collect() args = parse_args() logging_dir = Path(args.output_dir, args.logging_dir) accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, log_with=args.report_to, project_config=accelerator_project_config, ) logging.basicConfig(level=logging.INFO) if args.seed is not None: set_seed(args.seed) if accelerator.is_main_process: if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) weight_dtype = torch.float32 if accelerator.mixed_precision == "fp16": weight_dtype = torch.float16 elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 # ======================================================================== # PHASE 1: Text Embedding Pre-computation # ======================================================================== logger.info("PHASE 1: Text Embedding Pre-computation") data_entries = [] image_files = [] for ext in ["*.png", "*.jpg", "*.jpeg", "*.webp"]: image_files.extend(glob.glob(os.path.join(args.train_data_dir, "**", ext), recursive=True)) image_files.extend(glob.glob(os.path.join(args.train_data_dir, "**", ext.upper()), recursive=True)) for img_path in image_files: txt_path = os.path.splitext(img_path)[0] + ".txt" if os.path.exists(txt_path): with open(txt_path, "r", encoding="utf-8") as f: caption = f.read().strip() if caption: data_entries.append({"image_path": img_path, "caption": caption}) if not data_entries: raise ValueError("No images found!") quantization_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=weight_dtype, ) tokenizer = AutoTokenizer.from_pretrained( args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, trust_remote_code=True ) if hasattr(tokenizer, "pad_token") and tokenizer.pad_token is None: if hasattr(tokenizer, "eos_token"): tokenizer.pad_token = tokenizer.eos_token else: tokenizer.add_special_tokens({'pad_token': ''}) text_encoder = AutoModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", quantization_config=quantization_config, trust_remote_code=True, revision=args.revision, ) precomputed_dataset = [] batch_size = 1 logger.info(f"Encoding {len(data_entries)} captions...") for i in tqdm(range(0, len(data_entries), batch_size), desc="Encoding"): batch = data_entries[i : i + batch_size] captions = [x["caption"] for x in batch] inputs = tokenizer( captions, max_length=512, padding="max_length", truncation=True, return_tensors="pt" ) input_ids = inputs.input_ids.to(accelerator.device) with torch.no_grad(): encoder_outputs = text_encoder(input_ids, output_hidden_states=True, return_dict=True) embeddings = encoder_outputs.last_hidden_state.to(dtype=weight_dtype).cpu() for j, item in enumerate(batch): precomputed_dataset.append({ "image": item["image_path"], "encoder_hidden_states": embeddings[j] }) del text_encoder del tokenizer torch.cuda.empty_cache() gc.collect() # ======================================================================== # PHASE 2: Model Loading # ======================================================================== logger.info("PHASE 2: Loading Transformer...") pipe = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, trust_remote_code=True, revision=args.revision, variant=args.variant, text_encoder=None, tokenizer=None ) transformer = pipe.transformer vae = pipe.vae noise_scheduler = pipe.scheduler del pipe forward_signature = inspect.signature(transformer.forward) params = forward_signature.parameters param_names = list(params.keys()) logger.info(f"Detected Transformer forward params: {param_names}") input_arg = "hidden_states" if "x" in param_names: input_arg = "x" elif "sample" in param_names: input_arg = "sample" time_arg = "t" if "timestep" in param_names: time_arg = "timestep" cond_arg = "encoder_hidden_states" if "cap_feats" in param_names: cond_arg = "cap_feats" elif "context" in param_names: cond_arg = "context" logger.info(f"Mapping: Input='{input_arg}', Time='{time_arg}', Cond='{cond_arg}'") transformer.requires_grad_(False) vae.requires_grad_(False) transformer.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=weight_dtype) transformer_lora_config = LoraConfig( r=args.rank, lora_alpha=args.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"], ) transformer.add_adapter(transformer_lora_config) if args.mixed_precision == "fp16": cast_training_params(transformer, dtype=torch.float32) if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): if hasattr(transformer, "enable_xformers_memory_efficient_attention"): transformer.enable_xformers_memory_efficient_attention() if args.gradient_checkpointing: transformer.enable_gradient_checkpointing() if args.scale_lr: args.learning_rate = ( args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) if args.use_8bit_adam: import bitsandbytes as bnb optimizer_cls = bnb.optim.AdamW8bit else: optimizer_cls = torch.optim.AdamW lora_layers = filter(lambda p: p.requires_grad, transformer.parameters()) optimizer = optimizer_cls( lora_layers, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, ) dataset = Dataset.from_list(precomputed_dataset).cast_column("image", datasets.Image()) dataset = DatasetDict({"train": dataset}) interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None) train_transforms = transforms.Compose( [ transforms.Resize(args.resolution, interpolation=interpolation), transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ] ) def preprocess_train(examples): images = [image.convert("RGB") for image in examples["image"]] examples["pixel_values"] = [train_transforms(image) for image in images] return examples with accelerator.main_process_first(): if args.max_train_samples is not None: dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) train_dataset = dataset["train"].with_transform(preprocess_train) def collate_fn(examples): pixel_values = torch.stack([example["pixel_values"] for example in examples]) pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() encoder_hidden_states = torch.stack([torch.tensor(example["encoder_hidden_states"]) for example in examples]) return {"pixel_values": pixel_values, "encoder_hidden_states": encoder_hidden_states} train_dataloader = torch.utils.data.DataLoader( train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.train_batch_size, num_workers=args.dataloader_num_workers, ) num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes if args.max_train_steps is None: len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes) num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps) num_training_steps_for_scheduler = args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes else: num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes lr_scheduler = get_scheduler( args.lr_scheduler, optimizer=optimizer, num_warmup_steps=num_warmup_steps_for_scheduler, num_training_steps=num_training_steps_for_scheduler, ) transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( transformer, optimizer, train_dataloader, lr_scheduler ) num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) global_step = 0 first_epoch = 0 resume_step = 0 if args.resume_from_checkpoint: logger.info(f"Resuming from checkpoint {args.resume_from_checkpoint}") accelerator.load_state(args.resume_from_checkpoint) checkpoint_path = args.resume_from_checkpoint.rstrip('/') folder_name = os.path.basename(checkpoint_path) if folder_name.startswith("checkpoint-") and "-" in folder_name: try: global_step = int(folder_name.split("-")[1]) except: pass if global_step == 0: saved_state_path = os.path.join(checkpoint_path, "saved_state.json") if os.path.exists(saved_state_path): try: import json with open(saved_state_path, "r") as f: saved_state = json.load(f) global_step = saved_state.get("global_step", 0) except: pass first_epoch = global_step // num_update_steps_per_epoch resume_step = global_step % num_update_steps_per_epoch logger.info(f"Resuming from epoch {first_epoch}, step {resume_step}") # 训练循环 for epoch in range(first_epoch, args.num_train_epochs): transformer.train() loss_accumulator = 0.0 if args.resume_from_checkpoint and epoch == first_epoch: train_dataloader_skip = accelerator.skip_first_batches(train_dataloader, resume_step * args.train_batch_size) else: train_dataloader_skip = train_dataloader for step, batch in enumerate(train_dataloader_skip): with accelerator.accumulate(transformer): latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() latents = latents * vae.config.scaling_factor noise = torch.randn_like(latents) bsz = latents.shape[0] timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long() sigmas = timesteps.flatten() / noise_scheduler.config.num_train_timesteps while sigmas.ndim < latents.ndim: sigmas = sigmas.unsqueeze(-1) noisy_latents = (1.0 - sigmas) * latents + sigmas * noise target = noise - latents # Z-Image 输入需要 5D noisy_latents = noisy_latents.unsqueeze(2) encoder_hidden_states = batch["encoder_hidden_states"].to(dtype=weight_dtype) forward_kwargs = { input_arg: noisy_latents, time_arg: timesteps, cond_arg: encoder_hidden_states, "return_dict": False } if "pooled_projections" in param_names: forward_kwargs["pooled_projections"] = torch.zeros( (bsz, 2048), device=latents.device, dtype=weight_dtype ) model_pred = transformer(**forward_kwargs)[0] if isinstance(model_pred, list): model_pred = model_pred[0] # === 【最后一步核心修复】处理维度颠倒 === # 1. 如果输出是 5 维的,先把那个 1 维压扁 if model_pred.ndim == 5: model_pred = model_pred.squeeze(2) # 2. 如果输出形状和 Target 形状颠倒了 (Channel <-> Batch),则转置回来 # Target: [1, 16, 128, 128] (Batch, Channel, H, W) # Pred: [16, 1, 128, 128] (Channel, Batch, H, W) -> 需要修正 if model_pred.shape != target.shape: if model_pred.shape[0] == target.shape[1] and model_pred.shape[1] == target.shape[0]: # 执行转置,把 Channel 和 Batch 换回来 model_pred = model_pred.transpose(0, 1) loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") loss_accumulator += loss.detach().item() accelerator.backward(loss) if accelerator.sync_gradients: accelerator.clip_grad_norm_(lora_layers, args.max_grad_norm) optimizer.step() lr_scheduler.step() optimizer.zero_grad() if accelerator.sync_gradients: global_step += 1 if global_step % args.checkpointing_steps == 0: if accelerator.is_main_process: save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) if accelerator.is_main_process: avg_loss = loss_accumulator / args.gradient_accumulation_steps print(f"Steps: {global_step}/{args.max_train_steps} | Loss: {avg_loss:.4f}") loss_accumulator = 0.0 if global_step >= args.max_train_steps: break if global_step >= args.max_train_steps: break # 保存权重 if accelerator.is_main_process: transformer = accelerator.unwrap_model(transformer) transformer = transformer.to(torch.float32) transformer_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(transformer)) new_state_dict = {} for k, v in transformer_lora_state_dict.items(): new_state_dict[f"transformer.{k}"] = v save_path = os.path.join(args.output_dir, "pytorch_lora_weights.safetensors") save_file(new_state_dict, save_path) logger.info(f"Saved LoRA weights to {save_path}") accelerator.end_training() if __name__ == "__main__": main()