aifeifei798 committed
Commit 99f89a6 · verified · 1 Parent(s): 6055375

Upload 6 files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ fiefei_pic/2024-11-07_09-52-37_9988.png filter=lfs diff=lfs merge=lfs -text
+ fiefei_pic/2024-11-07_09-53-47_4210.png filter=lfs diff=lfs merge=lfs -text
fiefei_pic/2024-11-07_09-52-37_9988.png ADDED

Git LFS Details

  • SHA256: 2f31d83bbb5db7023e1f2dbd8735a4e62d0d318aa52b99d7d556d64f7822c443
  • Pointer size: 132 Bytes
  • Size of remote file: 1.18 MB
fiefei_pic/2024-11-07_09-52-37_9988.txt ADDED
@@ -0,0 +1 @@
+ A close-up portrait of a jpop girl with long brown hair. She is wearing a white corset with a lace trim around the neckline. Her face is turned slightly to the left, and she is looking directly at the camera. The background is blurred, but it appears to be an indoor setting with a glass window and a plant.
fiefei_pic/2024-11-07_09-53-47_4210.png ADDED

Git LFS Details

  • SHA256: e7ec034998ee99e3dbe75ce18d9e1d04b8166cbd742b7b064a7da59fa5222b24
  • Pointer size: 132 Bytes
  • Size of remote file: 1.28 MB
fiefei_pic/2024-11-07_09-53-47_4210.txt ADDED
@@ -0,0 +1 @@
+ A close-up portrait of a jpop girl with long brown hair, wearing a red satin dress with a white lace trim. The background is blurred, but it appears to be an indoor setting. The woman is looking directly at the camera.
run.sh ADDED
@@ -0,0 +1,21 @@
+ #!/bin/bash
+
+ # Check whether bitsandbytes is installed; install it if missing (required for 8-bit Adam)
+ # pip install bitsandbytes > /dev/null 2>&1
+
+ accelerate launch --mixed_precision="bf16" train_zimage_lora.py \
+   --pretrained_model_name_or_path="../../../smodels/Z-Image-Turbo" \
+   --train_data_dir="../../../datasets/fiefei_pic" \
+   --resolution=1024 \
+   --train_batch_size=1 \
+   --gradient_accumulation_steps=4 \
+   --max_train_steps=1000 \
+   --learning_rate=1e-4 \
+   --mixed_precision="bf16" \
+   --output_dir="feifei-zimage-lora" \
+   --caption_column="text" \
+   --rank=64 \
+   --gradient_checkpointing \
+   --use_8bit_adam \
+   --checkpointing_steps=100 \
+   --checkpoints_total_limit=3
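
Note on the data layout and schedule: train_zimage_lora.py (added below) builds its training set by scanning --train_data_dir recursively for *.png/*.jpg/*.jpeg/*.webp files and pairing each image with the caption read from a .txt file of the same name, which is exactly the layout of the fiefei_pic pairs in this commit (paths as referenced from run.sh):

    datasets/fiefei_pic/
        2024-11-07_09-52-37_9988.png
        2024-11-07_09-52-37_9988.txt    <- caption for the image above
        2024-11-07_09-53-47_4210.png
        2024-11-07_09-53-47_4210.txt

With the launch settings above (train_batch_size=1, gradient_accumulation_steps=4, max_train_steps=1000), each optimizer update accumulates gradients over 4 images, so the run draws roughly 4,000 image samples (with repetition from this small set) before stopping.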
train_zimage_lora.py ADDED
@@ -0,0 +1,475 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+ import argparse
+ import logging
+ import math
+ import os
+ from safetensors.torch import save_file  # Add this import
+ import random
+ import shutil
+ import glob
+ import gc
+ import inspect
+ from contextlib import nullcontext
+ from pathlib import Path
+
+ import datasets
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import torch.utils.checkpoint
+ import transformers
+ from accelerate import Accelerator
+ from accelerate.logging import get_logger
+ from accelerate.utils import ProjectConfiguration, set_seed
+ from datasets import load_dataset, Dataset, DatasetDict
+ from huggingface_hub import create_repo, upload_folder
+ from packaging import version
+ from peft import LoraConfig
+ from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict
+ from torchvision import transforms
+ from tqdm.auto import tqdm
+
+ from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig
+
+ import diffusers
+ from diffusers import (
+     AutoencoderKL,
+     DDPMScheduler,
+     DiffusionPipeline,
+     Transformer2DModel,
+     FlowMatchEulerDiscreteScheduler
+ )
+ from diffusers.optimization import get_scheduler
+ from diffusers.training_utils import cast_training_params, compute_snr
+ from diffusers.utils import (
+     check_min_version,
+     convert_state_dict_to_diffusers,
+     is_wandb_available,
+ )
+ from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+ from diffusers.utils.import_utils import is_xformers_available
+
+ if is_wandb_available():
+     import wandb
+
+ check_min_version("0.24.0")
+
+ logger = get_logger(__name__, log_level="INFO")
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Simple example of a training script.")
+     parser.add_argument("--pretrained_model_name_or_path", type=str, required=True)
+     parser.add_argument("--revision", type=str, default=None)
+     parser.add_argument("--variant", type=str, default=None)
+     parser.add_argument("--dataset_name", type=str, default=None)
+     parser.add_argument("--dataset_config_name", type=str, default=None)
+     parser.add_argument("--train_data_dir", type=str, default=None)
+     parser.add_argument("--image_column", type=str, default="image")
+     parser.add_argument("--caption_column", type=str, default="text")
+     parser.add_argument("--validation_prompt", type=str, default=None)
+     parser.add_argument("--num_validation_images", type=int, default=4)
+     parser.add_argument("--validation_epochs", type=int, default=1)
+     parser.add_argument("--max_train_samples", type=int, default=None)
+     parser.add_argument("--output_dir", type=str, default="z-image-lora")
+     parser.add_argument("--cache_dir", type=str, default=None)
+     parser.add_argument("--seed", type=int, default=None)
+     parser.add_argument("--resolution", type=int, default=1024)
+     parser.add_argument("--center_crop", default=False, action="store_true")
+     parser.add_argument("--random_flip", action="store_true")
+     parser.add_argument("--train_batch_size", type=int, default=1)
+     parser.add_argument("--num_train_epochs", type=int, default=100)
+     parser.add_argument("--max_train_steps", type=int, default=None)
+     parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
+     parser.add_argument("--gradient_checkpointing", action="store_true")
+     parser.add_argument("--learning_rate", type=float, default=1e-4)
+     parser.add_argument("--scale_lr", action="store_true", default=False)
+     parser.add_argument("--lr_scheduler", type=str, default="constant")
+     parser.add_argument("--lr_warmup_steps", type=int, default=500)
+     parser.add_argument("--snr_gamma", type=float, default=None)
+     parser.add_argument("--use_8bit_adam", action="store_true")
+     parser.add_argument("--allow_tf32", action="store_true")
+     parser.add_argument("--dataloader_num_workers", type=int, default=0)
+     parser.add_argument("--adam_beta1", type=float, default=0.9)
+     parser.add_argument("--adam_beta2", type=float, default=0.999)
+     parser.add_argument("--adam_weight_decay", type=float, default=1e-2)
+     parser.add_argument("--adam_epsilon", type=float, default=1e-08)
+     parser.add_argument("--max_grad_norm", default=1.0, type=float)
+     parser.add_argument("--push_to_hub", action="store_true")
+     parser.add_argument("--hub_token", type=str, default=None)
+     parser.add_argument("--prediction_type", type=str, default=None)
+     parser.add_argument("--hub_model_id", type=str, default=None)
+     parser.add_argument("--logging_dir", type=str, default="logs")
+     parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"])
+     parser.add_argument("--report_to", type=str, default="tensorboard")
+     parser.add_argument("--local_rank", type=int, default=-1)
+     parser.add_argument("--checkpointing_steps", type=int, default=500)
+     parser.add_argument("--checkpoints_total_limit", type=int, default=None)
+     parser.add_argument("--resume_from_checkpoint", type=str, default=None)
+     parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true")
+     parser.add_argument("--noise_offset", type=float, default=0)
+     parser.add_argument("--rank", type=int, default=4)
+     parser.add_argument("--image_interpolation_mode", type=str, default="lanczos")
+
+     args = parser.parse_args()
+     return args
+
+ def main():
+     torch.cuda.empty_cache()
+     gc.collect()
+
+     args = parse_args()
+     logging_dir = Path(args.output_dir, args.logging_dir)
+     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+     accelerator = Accelerator(
+         gradient_accumulation_steps=args.gradient_accumulation_steps,
+         mixed_precision=args.mixed_precision,
+         log_with=args.report_to,
+         project_config=accelerator_project_config,
+     )
+
+     logging.basicConfig(level=logging.INFO)
+     if args.seed is not None:
+         set_seed(args.seed)
+
+     if accelerator.is_main_process:
+         if args.output_dir is not None:
+             os.makedirs(args.output_dir, exist_ok=True)
+
+     weight_dtype = torch.float32
+     if accelerator.mixed_precision == "fp16":
+         weight_dtype = torch.float16
+     elif accelerator.mixed_precision == "bf16":
+         weight_dtype = torch.bfloat16
+
+     # ========================================================================
+     # PHASE 1: Text Embedding Pre-computation
+     # ========================================================================
+     logger.info("PHASE 1: Text Embedding Pre-computation")
+
+     data_entries = []
+     image_files = []
+     for ext in ["*.png", "*.jpg", "*.jpeg", "*.webp"]:
+         image_files.extend(glob.glob(os.path.join(args.train_data_dir, "**", ext), recursive=True))
+         image_files.extend(glob.glob(os.path.join(args.train_data_dir, "**", ext.upper()), recursive=True))
+
+     for img_path in image_files:
+         txt_path = os.path.splitext(img_path)[0] + ".txt"
+         if os.path.exists(txt_path):
+             with open(txt_path, "r", encoding="utf-8") as f:
+                 caption = f.read().strip()
+             if caption:
+                 data_entries.append({"image_path": img_path, "caption": caption})
+
+     if not data_entries:
+         raise ValueError("No images found!")
+
+     quantization_config = BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_quant_type="nf4",
+         bnb_4bit_compute_dtype=weight_dtype,
+     )
+
+     tokenizer = AutoTokenizer.from_pretrained(
+         args.pretrained_model_name_or_path,
+         subfolder="tokenizer",
+         revision=args.revision,
+         trust_remote_code=True
+     )
+     if hasattr(tokenizer, "pad_token") and tokenizer.pad_token is None:
+         if hasattr(tokenizer, "eos_token"):
+             tokenizer.pad_token = tokenizer.eos_token
+         else:
+             tokenizer.add_special_tokens({'pad_token': '<pad>'})
+
+     text_encoder = AutoModel.from_pretrained(
+         args.pretrained_model_name_or_path,
+         subfolder="text_encoder",
+         quantization_config=quantization_config,
+         trust_remote_code=True,
+         revision=args.revision,
+     )
+
+     precomputed_dataset = []
+     batch_size = 1
+
+     logger.info(f"Encoding {len(data_entries)} captions...")
+     for i in tqdm(range(0, len(data_entries), batch_size), desc="Encoding"):
+         batch = data_entries[i : i + batch_size]
+         captions = [x["caption"] for x in batch]
+
+         inputs = tokenizer(
+             captions,
+             max_length=512,
+             padding="max_length",
+             truncation=True,
+             return_tensors="pt"
+         )
+         input_ids = inputs.input_ids.to(accelerator.device)
+
+         with torch.no_grad():
+             encoder_outputs = text_encoder(input_ids, output_hidden_states=True, return_dict=True)
+             embeddings = encoder_outputs.last_hidden_state.to(dtype=weight_dtype).cpu()
+
+         for j, item in enumerate(batch):
+             precomputed_dataset.append({
+                 "image": item["image_path"],
+                 "encoder_hidden_states": embeddings[j]
+             })
+
+     del text_encoder
+     del tokenizer
+     torch.cuda.empty_cache()
+     gc.collect()
+
+     # ========================================================================
+     # PHASE 2: Model Loading
+     # ========================================================================
+     logger.info("PHASE 2: Loading Transformer...")
+
+     pipe = DiffusionPipeline.from_pretrained(
+         args.pretrained_model_name_or_path,
+         trust_remote_code=True,
+         revision=args.revision,
+         variant=args.variant,
+         text_encoder=None,
+         tokenizer=None
+     )
+
+     transformer = pipe.transformer
+     vae = pipe.vae
+     noise_scheduler = pipe.scheduler
+     del pipe
+
+     # Analyze Params
+     forward_signature = inspect.signature(transformer.forward)
+     params = forward_signature.parameters
+     param_names = list(params.keys())
+     logger.info(f"Detected Transformer forward params: {param_names}")
+
+     input_arg = "hidden_states"
+     if "x" in param_names: input_arg = "x"
+     elif "sample" in param_names: input_arg = "sample"
+
+     cond_arg = "encoder_hidden_states"
+     if "cap_feats" in param_names: cond_arg = "cap_feats"
+     elif "context" in param_names: cond_arg = "context"
+
+     logger.info(f"Mapping: Input='{input_arg}', Cond='{cond_arg}'")
+
+     transformer.requires_grad_(False)
+     vae.requires_grad_(False)
+
+     transformer.to(accelerator.device, dtype=weight_dtype)
+     vae.to(accelerator.device, dtype=weight_dtype)
+
+     transformer_lora_config = LoraConfig(
+         r=args.rank,
+         lora_alpha=args.rank,
+         init_lora_weights="gaussian",
+         target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+     )
+     transformer.add_adapter(transformer_lora_config)
+
+     if args.mixed_precision == "fp16":
+         cast_training_params(transformer, dtype=torch.float32)
+
+     if args.enable_xformers_memory_efficient_attention:
+         if is_xformers_available():
+             if hasattr(transformer, "enable_xformers_memory_efficient_attention"):
+                 transformer.enable_xformers_memory_efficient_attention()
+
+     if args.gradient_checkpointing:
+         transformer.enable_gradient_checkpointing()
+
+     if args.scale_lr:
+         args.learning_rate = (
+             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+         )
+
+     if args.use_8bit_adam:
+         import bitsandbytes as bnb
+         optimizer_cls = bnb.optim.AdamW8bit
+     else:
+         optimizer_cls = torch.optim.AdamW
+
+     lora_layers = list(filter(lambda p: p.requires_grad, transformer.parameters()))  # list() so the params can be reused for grad clipping below
+     optimizer = optimizer_cls(
+         lora_layers,
+         lr=args.learning_rate,
+         betas=(args.adam_beta1, args.adam_beta2),
+         weight_decay=args.adam_weight_decay,
+         eps=args.adam_epsilon,
+     )
+
+     dataset = Dataset.from_list(precomputed_dataset).cast_column("image", datasets.Image())
+     dataset = DatasetDict({"train": dataset})
+
+     interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
+     train_transforms = transforms.Compose(
+         [
+             transforms.Resize(args.resolution, interpolation=interpolation),
+             transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
+             transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
+             transforms.ToTensor(),
+             transforms.Normalize([0.5], [0.5]),
+         ]
+     )
+
+     def preprocess_train(examples):
+         images = [image.convert("RGB") for image in examples["image"]]
+         examples["pixel_values"] = [train_transforms(image) for image in images]
+         return examples
+
+     with accelerator.main_process_first():
+         if args.max_train_samples is not None:
+             dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+         train_dataset = dataset["train"].with_transform(preprocess_train)
+
+     def collate_fn(examples):
+         pixel_values = torch.stack([example["pixel_values"] for example in examples])
+         pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+         encoder_hidden_states = torch.stack([torch.tensor(example["encoder_hidden_states"]) for example in examples])
+         return {"pixel_values": pixel_values, "encoder_hidden_states": encoder_hidden_states}
+
+     train_dataloader = torch.utils.data.DataLoader(
+         train_dataset,
+         shuffle=True,
+         collate_fn=collate_fn,
+         batch_size=args.train_batch_size,
+         num_workers=args.dataloader_num_workers,
+     )
+
+     num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
+     if args.max_train_steps is None:
+         len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+         num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+         num_training_steps_for_scheduler = args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+     else:
+         num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
+
+     lr_scheduler = get_scheduler(
+         args.lr_scheduler,
+         optimizer=optimizer,
+         num_warmup_steps=num_warmup_steps_for_scheduler,
+         num_training_steps=num_training_steps_for_scheduler,
+     )
+
+     transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+         transformer, optimizer, train_dataloader, lr_scheduler
+     )
+
+     logger.info("***** Running training (Final Fix) *****")
+     global_step = 0
+     first_epoch = 0
+
+     progress_bar = tqdm(range(0, args.max_train_steps), initial=0, desc="Steps", disable=not accelerator.is_local_main_process)
+
+     for epoch in range(first_epoch, args.num_train_epochs):
+         transformer.train()
+         train_loss = 0.0
+         for step, batch in enumerate(train_dataloader):
+             with accelerator.accumulate(transformer):
+                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
+                 latents = latents * vae.config.scaling_factor
+                 noise = torch.randn_like(latents)
+                 bsz = latents.shape[0]
+
+                 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long()
+
+                 sigmas = timesteps.flatten() / noise_scheduler.config.num_train_timesteps
+                 while sigmas.ndim < latents.ndim:
+                     sigmas = sigmas.unsqueeze(-1)
+
+                 noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
+                 target = noise - latents
+
+                 # <--- Core fix: unsqueeze to (Batch, Channels, Frames=1, H, W) --->
+                 noisy_latents = noisy_latents.unsqueeze(2)
+
+                 encoder_hidden_states = batch["encoder_hidden_states"].to(dtype=weight_dtype)
+
+                 forward_kwargs = {
+                     input_arg: noisy_latents,
+                     "t": timesteps,  # Ensure t is passed
+                     cond_arg: encoder_hidden_states,
+                     "return_dict": False
+                 }
+
+                 model_pred = transformer(**forward_kwargs)[0]
+
+                 # 1. Handle case where output is a list (fixes AttributeError: 'list' object has no attribute 'ndim')
+                 if isinstance(model_pred, list):
+                     model_pred = model_pred[0]
+
+                 # 2. Handle 5D output (Video/Motion modules)
+                 if model_pred.ndim == 5:
+                     model_pred = model_pred.squeeze(2)
+
+                 # 3. Fix Shape Mismatch: Target [B, C, H, W] vs Pred [C, B, H, W] or similar
+                 # The error showed Target [1, 16, 128, 128] vs Pred [16, 1, 128, 128]
+                 if model_pred.shape != target.shape:
+                     # If dimensions are just swapped (e.g., [16, 1] vs [1, 16]), transpose them
+                     if model_pred.shape[0] == target.shape[1] and model_pred.shape[1] == target.shape[0]:
+                         model_pred = model_pred.transpose(0, 1)
+                     # If shapes still don't match but have the same number of elements, force a reshape
+                     elif model_pred.numel() == target.numel():
+                         model_pred = model_pred.reshape(target.shape)
+
+                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+                 avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+                 train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+                 accelerator.backward(loss)
+                 if accelerator.sync_gradients:
+                     accelerator.clip_grad_norm_(lora_layers, args.max_grad_norm)
+                 optimizer.step()
+                 lr_scheduler.step()
+                 optimizer.zero_grad()
+
+             if accelerator.sync_gradients:
+                 progress_bar.update(1)
+                 global_step += 1
+                 accelerator.log({"train_loss": train_loss}, step=global_step)
+                 train_loss = 0.0
+
+                 if global_step % args.checkpointing_steps == 0:
+                     if accelerator.is_main_process:
+                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+                         accelerator.save_state(save_path)
+
+             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+             progress_bar.set_postfix(**logs)
+
+             if global_step >= args.max_train_steps:
+                 break
+
+         # === Core change: two lines added here ===
+         # Once the step budget is reached, also break out of the outer epoch loop; otherwise it keeps spinning idly
+         if global_step >= args.max_train_steps:
+             break
+         # ============================
+
+     # === This block is indented 4 spaces (inside main(), after the epoch loop) ===
+     if accelerator.is_main_process:
+         transformer = accelerator.unwrap_model(transformer)
+         transformer = transformer.to(torch.float32)
+         transformer_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(transformer))
+
+         # === Core fix: add the "transformer." prefix to the parameter names ===
+         # This step is critical; without it, load_lora_weights cannot recognize the keys
+         new_state_dict = {}
+         for k, v in transformer_lora_state_dict.items():
+             new_state_dict[f"transformer.{k}"] = v
+         # ================================================
+
+         # Save the prefixed weights with safetensors
+         save_path = os.path.join(args.output_dir, "pytorch_lora_weights.safetensors")
+         save_file(new_state_dict, save_path)  # note: the prefixed new_state_dict is saved here
+         logger.info(f"Saved LoRA weights to {save_path}")
+
+     accelerator.end_training()
+
+ if __name__ == "__main__":
+     main()
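
For reference, a minimal inference sketch of how the saved LoRA might be applied, assuming the Z-Image pipeline loaded with trust_remote_code exposes the standard Diffusers load_lora_weights API (the "transformer." key prefix added before saving is what lets that loader route the weights to the transformer). The prompt and sampler settings below are illustrative, not taken from this commit:

# inference_lora_sketch.py -- illustrative only
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "../../../smodels/Z-Image-Turbo",   # same base model path used in run.sh
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
pipe.to("cuda")

# Load the adapter produced by train_zimage_lora.py from the training output_dir.
# This call assumes the pipeline supports the standard Diffusers LoRA loader.
pipe.load_lora_weights("feifei-zimage-lora", weight_name="pytorch_lora_weights.safetensors")

image = pipe(
    prompt="A close-up portrait of a jpop girl with long brown hair",
    num_inference_steps=8,   # hypothetical turbo-style step count
    guidance_scale=1.0,      # hypothetical; adjust for the actual pipeline
).images[0]
image.save("lora_sample.png")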