kayte0342 commited on
Commit
50edcdc
·
verified ·
1 Parent(s): 5252516

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -46
app.py CHANGED
@@ -160,58 +160,50 @@ def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps
160
  return final_image
161
 
162
  @spaces.GPU(duration=70)
163
- def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_indices_json, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
164
  import json
165
- # Parse the JSON string into a list of indices
166
  selected_indices = json.loads(selected_indices_json)
 
167
 
168
- # Ensure at least one LoRA is selected
169
- if not selected_indices or len(selected_indices) == 0:
170
  raise gr.Error("You must select at least one LoRA before proceeding.")
171
 
172
- # Combine trigger words from all selected LoRAs into the prompt
173
  prompt_mash = prompt
174
  for idx in selected_indices:
175
  selected_lora = loras[idx]
176
  if "trigger_word" in selected_lora and selected_lora["trigger_word"]:
177
- # Prepend each trigger word to the prompt; you can adjust the order or separator as needed
178
  prompt_mash = f"{selected_lora['trigger_word']} {prompt_mash}"
179
 
180
- # Unload any previously loaded LoRA weights
181
  with calculateDuration("Unloading LoRA"):
182
  pipe.unload_lora_weights()
183
  pipe_i2i.unload_lora_weights()
184
 
185
- # Load each selected LoRA weight sequentially
186
  with calculateDuration("Loading LoRA weights"):
187
- # Use the image-to-image pipeline if an input image is provided, else the text-to-image pipeline
188
  pipe_to_use = pipe_i2i if image_input is not None else pipe
189
  for idx in selected_indices:
190
  selected_lora = loras[idx]
191
  weight_name = selected_lora.get("weights", None)
 
 
 
192
  pipe_to_use.load_lora_weights(
193
- selected_lora["repo"],
194
- weight_name=weight_name,
195
- low_cpu_mem_usage=True
 
196
  )
197
 
198
- # Optionally randomize the seed if requested
199
  with calculateDuration("Randomizing seed"):
200
  if randomize_seed:
201
- seed = random.randint(0, MAX_SEED)
202
 
203
- # Generate image(s)
204
  if image_input is not None:
205
- # Image-to-image generation
206
- final_image = generate_image_to_image(
207
- prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, lora_scale, seed
208
- )
209
  yield final_image, seed, gr.update(visible=False)
210
  else:
211
- # Text-to-image generation
212
- image_generator = generate_image(
213
- prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress
214
- )
215
  final_image = None
216
  step_counter = 0
217
  for image in image_generator:
@@ -222,6 +214,7 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
222
  yield final_image, seed, gr.update(value=progress_bar, visible=False)
223
 
224
 
 
225
 
226
  def get_huggingface_safetensors(link):
227
  split_link = link.split("/")
@@ -301,18 +294,18 @@ def remove_custom_lora():
301
  run_lora.zerogpu = True
302
 
303
  css = '''
304
- #gen_btn{height: 100%}
305
- #gen_column{align-self: stretch}
306
- #title{text-align: center}
307
- #title h1{font-size: 3em; display:inline-flex; align-items:center}
308
- #title img{width: 100px; margin-right: 0.5em}
309
- #lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
310
- .card_internal{display: flex;height: 100px;margin-top: .5em}
311
- .card_internal img{margin-right: 1em}
312
- .styler{--form-gap-width: 0px !important}
313
- #progress{height:30px}
314
- .progress-container {width: 100%;height: 30px;background-color: #f0f0f0;border-radius: 15px;overflow: hidden;margin-bottom: 20px}
315
- .progress-bar {height: 100%;background-color: #4f46e5;width: calc(var(--current) / var(--total) * 100%);transition: width 0.5s ease-in-out}
316
  '''
317
  font = [gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"]
318
 
@@ -321,8 +314,10 @@ with gr.Blocks(theme=gr.themes.Soft(font=font), css=css, delete_cache=(60, 60))
321
  """<h1><img src="https://huggingface.co/spaces/kayte0342/test/resolve/main/DA4BE61E-A0BD-4254-A1B6-AD3C05D18A9C%20(1).png?download=true" alt="LoRA"> FLUX LoRA Kayte's Space</h1>""",
322
  elem_id="title",
323
  )
324
- # Hidden textbox to store the JSON string of selected indices
 
325
  selected_indices_hidden = gr.Textbox(value="[]", visible=False)
 
326
 
327
  with gr.Row():
328
  with gr.Column(scale=3):
@@ -333,15 +328,19 @@ with gr.Blocks(theme=gr.themes.Soft(font=font), css=css, delete_cache=(60, 60))
333
  with gr.Row():
334
  with gr.Column():
335
  selected_info = gr.Markdown("")
336
- # Build a custom layout for LoRA selection with checkboxes.
337
  lora_selection_container = gr.Column()
 
338
  lora_checkbox_list = []
 
339
  for idx, lora in enumerate(loras):
340
  with gr.Row():
341
- gr.Image(value=lora["image"], label=lora["title"], height=100)
342
  checkbox = gr.Checkbox(label="Select", value=False, elem_id=f"lora_checkbox_{idx}")
 
343
  lora_checkbox_list.append(checkbox)
344
- gr.Markdown("[Check the list of FLUX LoRas](https://huggingface.co/models?other=base_model:adapter:black-forest-labs/FLUX.1-dev)", elem_id="lora_list")
 
345
  with gr.Column():
346
  progress_bar = gr.Markdown(elem_id="progress", visible=False)
347
  result = gr.Image(label="Generated Image")
@@ -361,27 +360,35 @@ with gr.Blocks(theme=gr.themes.Soft(font=font), css=css, delete_cache=(60, 60))
361
  with gr.Row():
362
  randomize_seed = gr.Checkbox(True, label="Randomize seed")
363
  seed = gr.Slider(label="Seed", minimum=0, maximum=2**32-1, step=1, value=0, randomize=True)
364
- lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=3, step=0.01, value=0.95)
365
 
366
- # Function to combine the checkbox values into a JSON string.
367
  def combine_selections(*checkbox_values):
368
  selected_indices = [i for i, v in enumerate(checkbox_values) if v]
369
  return json.dumps(selected_indices)
370
 
371
- # Chain the update: When the Generate button is clicked,
372
- # first update the hidden state with combine_selections,
373
- # then run run_lora using the updated hidden state.
 
 
 
 
374
  generate_button.click(
375
  combine_selections,
376
  inputs=lora_checkbox_list,
377
  outputs=selected_indices_hidden
 
 
 
 
378
  ).then(
379
  run_lora,
380
- inputs=[prompt, input_image, image_strength, cfg_scale, steps, selected_indices_hidden, randomize_seed, seed, width, height, lora_scale],
381
  outputs=[result, seed, progress_bar]
382
  )
383
 
384
- # Optionally, update the selected_info markdown when the hidden state changes.
385
  def update_info(selected_json):
386
  selected_indices = json.loads(selected_json)
387
  if selected_indices:
 
160
  return final_image
161
 
162
  @spaces.GPU(duration=70)
163
+ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_indices_json, selected_weights_json, randomize_seed, seed, width, height, global_lora_scale, progress=gr.Progress(track_tqdm=True)):
164
  import json
165
+ # Parse the JSON strings
166
  selected_indices = json.loads(selected_indices_json)
167
+ selected_weights = json.loads(selected_weights_json) if selected_weights_json else {}
168
 
169
+ if not selected_indices:
 
170
  raise gr.Error("You must select at least one LoRA before proceeding.")
171
 
172
+ # Combine trigger words from all selected LoRAs
173
  prompt_mash = prompt
174
  for idx in selected_indices:
175
  selected_lora = loras[idx]
176
  if "trigger_word" in selected_lora and selected_lora["trigger_word"]:
 
177
  prompt_mash = f"{selected_lora['trigger_word']} {prompt_mash}"
178
 
 
179
  with calculateDuration("Unloading LoRA"):
180
  pipe.unload_lora_weights()
181
  pipe_i2i.unload_lora_weights()
182
 
 
183
  with calculateDuration("Loading LoRA weights"):
 
184
  pipe_to_use = pipe_i2i if image_input is not None else pipe
185
  for idx in selected_indices:
186
  selected_lora = loras[idx]
187
  weight_name = selected_lora.get("weights", None)
188
+ # Get the individual weight for this LoRA from the selected_weights mapping.
189
+ # If not found, default to 0.95.
190
+ lora_weight = selected_weights.get(str(idx), 0.95)
191
  pipe_to_use.load_lora_weights(
192
+ selected_lora["repo"],
193
+ weight_name=weight_name,
194
+ low_cpu_mem_usage=True,
195
+ lora_weight=lora_weight # This parameter should be supported by your load function.
196
  )
197
 
 
198
  with calculateDuration("Randomizing seed"):
199
  if randomize_seed:
200
+ seed = random.randint(0, 2**32-1)
201
 
 
202
  if image_input is not None:
203
+ final_image = generate_image_to_image(prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, global_lora_scale, seed)
 
 
 
204
  yield final_image, seed, gr.update(visible=False)
205
  else:
206
+ image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, global_lora_scale, progress)
 
 
 
207
  final_image = None
208
  step_counter = 0
209
  for image in image_generator:
 
214
  yield final_image, seed, gr.update(value=progress_bar, visible=False)
215
 
216
 
217
+
218
 
219
  def get_huggingface_safetensors(link):
220
  split_link = link.split("/")
 
294
  run_lora.zerogpu = True
295
 
296
  css = '''
297
+ #gen_btn { height: 100%; }
298
+ #gen_column { align-self: stretch; }
299
+ #title { text-align: center; }
300
+ #title h1 { font-size: 3em; display: inline-flex; align-items: center; }
301
+ #title img { width: 100px; margin-right: 0.5em; }
302
+ #lora_list { background: var(--block-background-fill); padding: 0 1em .3em; font-size: 90%; }
303
+ .card_internal { display: flex; height: 100px; margin-top: .5em; }
304
+ .card_internal img { margin-right: 1em; }
305
+ .styler { --form-gap-width: 0px !important; }
306
+ #progress { height: 30px; }
307
+ .progress-container { width: 100%; height: 30px; background-color: #f0f0f0; border-radius: 15px; overflow: hidden; margin-bottom: 20px; }
308
+ .progress-bar { height: 100%; background-color: #4f46e5; width: calc(var(--current) / var(--total) * 100%); transition: width 0.5s ease-in-out; }
309
  '''
310
  font = [gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"]
311
 
 
314
  """<h1><img src="https://huggingface.co/spaces/kayte0342/test/resolve/main/DA4BE61E-A0BD-4254-A1B6-AD3C05D18A9C%20(1).png?download=true" alt="LoRA"> FLUX LoRA Kayte's Space</h1>""",
315
  elem_id="title",
316
  )
317
+
318
+ # Hidden textboxes to store the JSON outputs:
319
  selected_indices_hidden = gr.Textbox(value="[]", visible=False)
320
+ selected_weights_hidden = gr.Textbox(value="{}", visible=False)
321
 
322
  with gr.Row():
323
  with gr.Column(scale=3):
 
328
  with gr.Row():
329
  with gr.Column():
330
  selected_info = gr.Markdown("")
331
+ # Build a custom layout for LoRA selection.
332
  lora_selection_container = gr.Column()
333
+ # We'll collect checkbox and slider components in lists.
334
  lora_checkbox_list = []
335
+ lora_slider_list = []
336
  for idx, lora in enumerate(loras):
337
  with gr.Row():
338
+ gr.Image(label=lora["title"], height=100)
339
  checkbox = gr.Checkbox(label="Select", value=False, elem_id=f"lora_checkbox_{idx}")
340
+ slider = gr.Slider(label="Weight", minimum=0, maximum=3, step=0.01, value=0.95, elem_id=f"lora_weight_{idx}")
341
  lora_checkbox_list.append(checkbox)
342
+ lora_slider_list.append(slider)
343
+ gr.Markdown("[Check the list of FLUX LoRAs](https://huggingface.co/models?other=base_model:adapter:black-forest-labs/FLUX.1-dev)", elem_id="lora_list")
344
  with gr.Column():
345
  progress_bar = gr.Markdown(elem_id="progress", visible=False)
346
  result = gr.Image(label="Generated Image")
 
360
  with gr.Row():
361
  randomize_seed = gr.Checkbox(True, label="Randomize seed")
362
  seed = gr.Slider(label="Seed", minimum=0, maximum=2**32-1, step=1, value=0, randomize=True)
363
+ lora_scale = gr.Slider(label="Global LoRA Scale", minimum=0, maximum=3, step=0.01, value=0.95)
364
 
365
+ # Function to combine checkbox selections into a JSON list of indices.
366
  def combine_selections(*checkbox_values):
367
  selected_indices = [i for i, v in enumerate(checkbox_values) if v]
368
  return json.dumps(selected_indices)
369
 
370
+ # Function to combine all slider values into a JSON dictionary mapping index to weight.
371
+ def combine_weights(*slider_values):
372
+ weights = {str(i): v for i, v in enumerate(slider_values)}
373
+ return json.dumps(weights)
374
+
375
+ # Chain the updates when the Generate button is clicked:
376
+ # First, update the checkbox hidden state, then update the slider hidden state, then call run_lora.
377
  generate_button.click(
378
  combine_selections,
379
  inputs=lora_checkbox_list,
380
  outputs=selected_indices_hidden
381
+ ).then(
382
+ combine_weights,
383
+ inputs=lora_slider_list,
384
+ outputs=selected_weights_hidden
385
  ).then(
386
  run_lora,
387
+ inputs=[prompt, input_image, image_strength, cfg_scale, steps, selected_indices_hidden, selected_weights_hidden, randomize_seed, seed, width, height, lora_scale],
388
  outputs=[result, seed, progress_bar]
389
  )
390
 
391
+ # Update the selected_info display when the selected_indices_hidden changes.
392
  def update_info(selected_json):
393
  selected_indices = json.loads(selected_json)
394
  if selected_indices: