aifeifei798 commited on
Commit
81fcd33
·
verified ·
1 Parent(s): 8243970

Update test_zimage.py

Browse files
Files changed (1) hide show
  1. test_zimage.py +54 -54
test_zimage.py CHANGED
@@ -1,55 +1,55 @@
1
- import torch
2
- from diffusers import ZImagePipeline
3
- import os
4
-
5
- # 1. Load the pipeline
6
- # Use bfloat16 for optimal performance on supported GPUs
7
- pipe = ZImagePipeline.from_pretrained(
8
- "../../../smodels/Z-Image-Turbo",
9
- torch_dtype=torch.bfloat16,
10
- low_cpu_mem_usage=False,
11
- )
12
-
13
- # =================【这里是新增的加载 LoRA 代码】=================
14
- # 指向你刚才训练输出的文件夹路径
15
- lora_dir = "./feifei-zimage-lora"
16
- lora_file = "pytorch_lora_weights.safetensors"
17
- full_path = os.path.join(lora_dir, lora_file)
18
-
19
- if os.path.exists(full_path):
20
- print(f"正在加载 LoRA: {full_path}")
21
- try:
22
- # adapter_name 可以随意起,用来标记这个 LoRA
23
- pipe.load_lora_weights(lora_dir, weight_name=lora_file, adapter_name="feifei")
24
- print("✅ LoRA 加载成功!")
25
-
26
- # [可选] 设置 LoRA 的权重强度 (1.0 = 100% 强度, 0.5 = 50%)
27
- # pipe.set_adapters(["feifei"], adapter_weights=[1.0])
28
-
29
- except Exception as e:
30
- print(f"❌ LoRA 加载失败: {e}")
31
- print("可能是键名不匹配,或者文件损坏。")
32
- else:
33
- print(f"⚠️ 找不到 LoRA 文件: {full_path}")
34
- # ===============================================================
35
-
36
- pipe.to("cuda")
37
-
38
- # [Optional] Attention Backend
39
- # pipe.transformer.set_attention_backend("flash")
40
-
41
- prompt = "jpop model in bikini at sea"
42
-
43
- # 2. Generate Image
44
- image = pipe(
45
- prompt=prompt,
46
- height=1024,
47
- width=1024,
48
- num_inference_steps=9,
49
- guidance_scale=0.0,
50
- generator=torch.Generator("cuda").manual_seed(42),
51
- # cross_attention_kwargs={"scale": 1.0} # 另一种控制 LoRA 强度的方法
52
- ).images[0]
53
-
54
- image.save("example_lora_test.png")
55
  print("图像已保存为 example_lora_test.png")
 
1
+ import torch
2
+ from diffusers import ZImagePipeline
3
+ import os
4
+
5
+ # 1. Load the pipeline
6
+ # Use bfloat16 for optimal performance on supported GPUs
7
+ pipe = ZImagePipeline.from_pretrained(
8
+ "./Z-Image-Turbo",
9
+ torch_dtype=torch.bfloat16,
10
+ low_cpu_mem_usage=False,
11
+ )
12
+
13
+ # =================【这里是新增的加载 LoRA 代码】=================
14
+ # 指向你刚才训练输出的文件夹路径
15
+ lora_dir = "./feifei-zimage-lora"
16
+ lora_file = "pytorch_lora_weights.safetensors"
17
+ full_path = os.path.join(lora_dir, lora_file)
18
+
19
+ if os.path.exists(full_path):
20
+ print(f"正在加载 LoRA: {full_path}")
21
+ try:
22
+ # adapter_name 可以随意起,用来标记这个 LoRA
23
+ pipe.load_lora_weights(lora_dir, weight_name=lora_file, adapter_name="feifei")
24
+ print("✅ LoRA 加载成功!")
25
+
26
+ # [可选] 设置 LoRA 的权重强度 (1.0 = 100% 强度, 0.5 = 50%)
27
+ # pipe.set_adapters(["feifei"], adapter_weights=[1.0])
28
+
29
+ except Exception as e:
30
+ print(f"❌ LoRA 加载失败: {e}")
31
+ print("可能是键名不匹配,或者文件损坏。")
32
+ else:
33
+ print(f"⚠️ 找不到 LoRA 文件: {full_path}")
34
+ # ===============================================================
35
+
36
+ pipe.to("cuda")
37
+
38
+ # [Optional] Attention Backend
39
+ # pipe.transformer.set_attention_backend("flash")
40
+
41
+ prompt = "jpop model in bikini at sea"
42
+
43
+ # 2. Generate Image
44
+ image = pipe(
45
+ prompt=prompt,
46
+ height=1024,
47
+ width=1024,
48
+ num_inference_steps=9,
49
+ guidance_scale=0.0,
50
+ generator=torch.Generator("cuda").manual_seed(42),
51
+ # cross_attention_kwargs={"scale": 1.0} # 另一种控制 LoRA 强度的方法
52
+ ).images[0]
53
+
54
+ image.save("example_lora_test.png")
55
  print("图像已保存为 example_lora_test.png")