RioShiina committed on
Commit e9c4f1a · verified · 1 Parent(s): 23bb9cf

Update app.py

Files changed (1)
  1. app.py +620 -522
app.py CHANGED
@@ -1,523 +1,621 @@
1
- import spaces
2
- import gradio as gr
3
- import numpy as np
4
- import PIL.Image
5
- from PIL import Image, PngImagePlugin
6
- import random
7
- from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler, DDIMScheduler, UniPCMultistepScheduler, HeunDiscreteScheduler, LMSDiscreteScheduler
8
- import torch
9
- from compel import Compel, ReturnedEmbeddingsType
10
- import requests
11
- import os
12
- import re
13
- import gc
14
- from huggingface_hub import hf_hub_download, snapshot_download
15
- import time
16
-
17
- # This dummy function is required to pass the Hugging Face Spaces startup check for GPU apps.
18
- @spaces.GPU(duration=60)
19
- def dummy_gpu_for_startup():
20
- print("Dummy function for startup check executed. This is normal.")
21
- return "Startup check passed."
22
-
23
- # --- Constants ---
24
- MAX_LORAS = 5
25
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
- MAX_SEED = np.iinfo(np.int64).max
27
- MAX_IMAGE_SIZE = 1216
28
- SAMPLER_MAP = {
29
- "Euler a": EulerAncestralDiscreteScheduler,
30
- "Euler": EulerDiscreteScheduler,
31
- "DPM++ 2M Karras": DPMSolverMultistepScheduler,
32
- "DDIM": DDIMScheduler,
33
- "UniPC": UniPCMultistepScheduler,
34
- "Heun": HeunDiscreteScheduler,
35
- "LMS": LMSDiscreteScheduler,
36
- }
37
- SCHEDULE_TYPE_MAP = ["Default", "Karras", "Uniform", "SGM Uniform"]
38
- DEFAULT_SCHEDULE_TYPE = "Default"
39
- DEFAULT_SAMPLER = "Euler a"
40
- DEFAULT_NEGATIVE_PROMPT = "monochrome, (low quality, worst quality:1.2), 3d, watermark, signature, ugly, poorly drawn,"
41
- DOWNLOAD_DIR = "/tmp/loras"
42
- os.makedirs(DOWNLOAD_DIR, exist_ok=True)
43
-
44
- # --- Model Lists ---
45
- MODEL_LIST = [
46
- "dhead/wai-nsfw-illustrious-sdxl-v140-sdxl",
47
- "Laxhar/noobai-XL-Vpred-1.0",
48
- "John6666/hassaku-xl-illustrious-v30-sdxl",
49
- "RedRayz/hikari_noob_v-pred_1.2.2",
50
- "bluepen5805/noob_v_pencil-XL",
51
- "Laxhar/noobai-XL-1.1"
52
- ]
53
-
54
- # --- List of V-Prediction Models ---
55
- V_PREDICTION_MODELS = [
56
- "Laxhar/noobai-XL-Vpred-1.0",
57
- "RedRayz/hikari_noob_v-pred_1.2.2",
58
- "bluepen5805/noob_v_pencil-XL"
59
- ]
60
-
61
- # --- Dictionary for single-file models now stores the filename ---
62
- SINGLE_FILE_MODELS = {
63
- "bluepen5805/noob_v_pencil-XL": "noob_v_pencil-XL-v3.0.0.safetensors"
64
- }
65
-
66
- # --- Model Hash to Name Mapping ---
67
- HASH_TO_MODEL_MAP = {
68
- "bdb59bac77": "dhead/wai-nsfw-illustrious-sdxl-v140-sdxl",
69
- "ea349eeae8": "Laxhar/noobai-XL-Vpred-1.0",
70
- "b4fb5f829a": "John6666/hassaku-xl-illustrious-v30-sdxl",
71
- "6681e8e4b1": "Laxhar/noobai-XL-1.1",
72
- "90b7911a78": "bluepen5805/noob_v_pencil-XL",
73
- "874170688a": "RedRayz/hikari_noob_v-pred_1.2.2"
74
- }
75
-
76
- def download_all_base_models_on_startup():
77
- """Downloads all base models listed in MODEL_LIST when the app starts."""
78
- print("--- Starting pre-download of all base models ---")
79
- for model_name in MODEL_LIST:
80
- try:
81
- print(f"Downloading: {model_name}...")
82
- start_time = time.time()
83
- # Handle single-file models
84
- if model_name in SINGLE_FILE_MODELS:
85
- filename = SINGLE_FILE_MODELS[model_name]
86
- hf_hub_download(repo_id=model_name, filename=filename)
87
- # Handle standard diffusers models
88
- else:
89
- snapshot_download(repo_id=model_name, ignore_patterns=["*.onnx", "*.flax"])
90
- end_time = time.time()
91
- print(f"✅ Successfully downloaded {model_name} in {end_time - start_time:.2f} seconds.")
92
- except Exception as e:
93
- print(f" Failed to download {model_name}: {e}")
94
- finally:
95
- # Clean up to conserve memory
96
- gc.collect()
97
- if torch.cuda.is_available():
98
- torch.cuda.empty_cache()
99
- print("--- Finished pre-downloading all base models ---")
100
-
101
-
102
- def get_civitai_file_info(version_id):
103
- """Gets the file metadata for a model version via the Civitai API."""
104
- api_url = f"https://civitai.com/api/v1/model-versions/{version_id}"
105
- try:
106
- response = requests.get(api_url)
107
- response.raise_for_status()
108
- data = response.json()
109
- for file_data in data.get('files', []):
110
- if file_data['name'].endswith('.safetensors'):
111
- return file_data
112
- if data.get('files'):
113
- return data['files'][0]
114
- return None
115
- except Exception as e:
116
- print(f"Could not get file info from Civitai API: {e}")
117
- return None
118
-
119
- def download_file(url, save_path, api_key=None, progress=None, desc=""):
120
- """Downloads a file, skipping if it already exists."""
121
- if os.path.exists(save_path):
122
- return f"File already exists: {os.path.basename(save_path)}"
123
-
124
- headers = {}
125
- if api_key and api_key.strip():
126
- headers['Authorization'] = f'Bearer {api_key}'
127
-
128
- try:
129
- if progress: progress(0, desc=desc)
130
- response = requests.get(url, stream=True, headers=headers)
131
- response.raise_for_status()
132
-
133
- total_size = int(response.headers.get('content-length', 0))
134
-
135
- with open(save_path, "wb") as f:
136
- downloaded = 0
137
- for chunk in response.iter_content(chunk_size=8192):
138
- f.write(chunk)
139
- if progress and total_size > 0:
140
- downloaded += len(chunk)
141
- progress(downloaded / total_size, desc=desc)
142
-
143
- return f"Successfully downloaded: {os.path.basename(save_path)}"
144
- except Exception as e:
145
- if os.path.exists(save_path): os.remove(save_path)
146
- return f"Download failed for {os.path.basename(save_path)}: {e}"
147
-
148
- def process_long_prompt(compel_proc, prompt, negative_prompt=""):
149
- try:
150
- conditioning, pooled = compel_proc([prompt, negative_prompt])
151
- return conditioning, pooled
152
- except Exception:
153
- return None, None
154
-
155
- def pre_download_loras(civitai_api_key, *lora_data, progress=gr.Progress(track_tqdm=True)):
156
- civitai_ids = lora_data[0::2]
157
- status_log = []
158
-
159
- active_lora_ids = [cid for cid in civitai_ids if cid and cid.strip()]
160
- if not active_lora_ids:
161
- return "No LoRA IDs provided to download."
162
-
163
- for i, civitai_id in enumerate(active_lora_ids):
164
- version_id = civitai_id.strip()
165
- progress(i / len(active_lora_ids), desc=f"Getting URL for LoRA ID: {version_id}")
166
-
167
- local_lora_path = os.path.join(DOWNLOAD_DIR, f"civitai_{version_id}.safetensors")
168
-
169
- file_info = get_civitai_file_info(version_id)
170
- if not file_info:
171
- status_log.append(f"* LoRA ID {version_id}: Could not get file info from Civitai.")
172
- continue
173
-
174
- download_url = file_info.get('downloadUrl')
175
- if not download_url:
176
- status_log.append(f"* LoRA ID {version_id}: Could not get download link.")
177
- continue
178
-
179
- status = download_file(
180
- download_url,
181
- local_lora_path,
182
- api_key=civitai_api_key,
183
- progress=progress,
184
- desc=f"Downloading LoRA ID: {version_id}"
185
- )
186
- status_log.append(f"* LoRA ID {version_id}: {status}")
187
-
188
- return "\n".join(status_log)
189
-
190
- def _infer_logic(base_model_name, prompt, negative_prompt, seed, batch_size, width, height, guidance_scale, num_inference_steps,
191
- sampler, schedule_type,
192
- civitai_api_key,
193
- *lora_data,
194
- progress=gr.Progress(track_tqdm=True)):
195
-
196
- pipe = None
197
- try:
198
- progress(0, desc=f"Loading model: {base_model_name}")
199
-
200
- if base_model_name in SINGLE_FILE_MODELS:
201
- filename = SINGLE_FILE_MODELS[base_model_name]
202
- print(f"Loading single file: {filename} from repo: {base_model_name}")
203
- local_path = hf_hub_download(repo_id=base_model_name, filename=filename)
204
- pipe = StableDiffusionXLPipeline.from_single_file(
205
- local_path,
206
- torch_dtype=torch.float16,
207
- use_safetensors=True
208
- )
209
- else:
210
- print(f"Loading diffusers model: {base_model_name}")
211
- pipe = StableDiffusionXLPipeline.from_pretrained(
212
- base_model_name,
213
- torch_dtype=torch.float16,
214
- use_safetensors=True
215
- )
216
- pipe.to(device)
217
-
218
- batch_size = int(batch_size)
219
- seed = int(seed)
220
-
221
- pipe.unload_lora_weights()
222
-
223
- scheduler_class = SAMPLER_MAP.get(sampler, EulerAncestralDiscreteScheduler)
224
- scheduler_config = pipe.scheduler.config
225
-
226
- if base_model_name in V_PREDICTION_MODELS:
227
- scheduler_config['prediction_type'] = 'v_prediction'
228
- else:
229
- scheduler_config['prediction_type'] = 'epsilon'
230
-
231
- scheduler_kwargs = {}
232
- if schedule_type == "Default" and sampler == "DPM++ 2M Karras":
233
- scheduler_kwargs['use_karras_sigmas'] = True
234
- elif schedule_type == "Karras":
235
- scheduler_kwargs['use_karras_sigmas'] = True
236
- elif schedule_type == "Uniform":
237
- scheduler_kwargs['use_karras_sigmas'] = False
238
- elif schedule_type == "SGM Uniform":
239
- scheduler_kwargs['algorithm_type'] = 'sgm_uniform'
240
-
241
- pipe.scheduler = scheduler_class.from_config(scheduler_config, **scheduler_kwargs)
242
-
243
- compel_type = ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED
244
- compel = Compel(tokenizer=[pipe.tokenizer, pipe.tokenizer_2], text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
245
- returned_embeddings_type=compel_type, requires_pooled=[False, True], truncate_long_prompts=False)
246
-
247
- civitai_ids, lora_scales = lora_data[0::2], lora_data[1::2]
248
- lora_params = list(zip(civitai_ids, lora_scales))
249
- active_loras, active_lora_names_for_meta = [], []
250
-
251
- for i, (civitai_id, lora_scale) in enumerate(lora_params):
252
- if civitai_id and civitai_id.strip() and lora_scale > 0:
253
- version_id = civitai_id.strip()
254
- local_lora_path = os.path.join(DOWNLOAD_DIR, f"civitai_{version_id}.safetensors")
255
-
256
- if not os.path.exists(local_lora_path):
257
- file_info = get_civitai_file_info(version_id)
258
- if not file_info:
259
- print(f"Could not get file info for Civitai ID {version_id}, skipping.")
260
- continue
261
-
262
- download_url = file_info.get('downloadUrl')
263
- if download_url:
264
- download_file(download_url, local_lora_path, api_key=civitai_api_key, progress=progress, desc=f"Downloading LoRA ID {version_id}")
265
- else:
266
- print(f"Could not get download link for Civitai ID {version_id} during inference, skipping."); continue
267
-
268
- if not os.path.exists(local_lora_path): print(f"LoRA file for ID {version_id} not found, skipping."); continue
269
-
270
- adapter_name = f"lora_{i+1}"
271
- progress((i * 0.1) + 0.05, desc=f"Loading LoRA (ID: {version_id})")
272
- pipe.load_lora_weights(local_lora_path, adapter_name=adapter_name)
273
- active_loras.append((adapter_name, lora_scale))
274
- active_lora_names_for_meta.append(f"LoRA {i+1} (ID: {version_id}, Weight: {lora_scale})")
275
-
276
- if active_loras:
277
- adapter_names, adapter_weights = zip(*active_loras); pipe.set_adapters(list(adapter_names), list(adapter_weights))
278
-
279
- conditioning, pooled = process_long_prompt(compel, prompt, negative_prompt)
280
-
281
- pipe_args = {
282
- "guidance_scale": guidance_scale,
283
- "num_inference_steps": num_inference_steps,
284
- "width": width,
285
- "height": height,
286
- }
287
-
288
- output_images = []
289
- loras_string = f"LoRAs: [{', '.join(active_lora_names_for_meta)}]" if active_lora_names_for_meta else ""
290
-
291
- for i in range(batch_size):
292
- progress(i / batch_size, desc=f"Generating image {i+1}/{batch_size}")
293
-
294
- if i == 0 and seed != -1:
295
- current_seed = seed
296
- else:
297
- current_seed = random.randint(0, MAX_SEED)
298
-
299
- generator = torch.Generator(device=device).manual_seed(current_seed)
300
- pipe_args["generator"] = generator
301
-
302
- if conditioning is not None:
303
- image = pipe(prompt_embeds=conditioning[0:1], pooled_prompt_embeds=pooled[0:1], negative_prompt_embeds=conditioning[1:2], negative_pooled_prompt_embeds=pooled[1:2], **pipe_args).images[0]
304
- else:
305
- image = pipe(prompt=prompt, negative_prompt=negative_prompt, **pipe_args).images[0]
306
-
307
- params_string = f"{prompt}\nNegative prompt: {negative_prompt}\n"
308
- params_string += f"Steps: {num_inference_steps}, Sampler: {sampler}, Schedule type: {schedule_type}, CFG scale: {guidance_scale}, Seed: {current_seed}, Size: {width}x{height}, Base Model: {base_model_name}, {loras_string}".strip()
309
- image.info = {'parameters': params_string}
310
- output_images.append(image)
311
-
312
- return output_images
313
-
314
- except Exception as e:
315
- print(f"An error occurred during generation: {e}"); raise gr.Error(f"Generation failed: {e}")
316
- finally:
317
- if pipe is not None:
318
- pipe.disable_lora()
319
- del pipe
320
- gc.collect()
321
- if torch.cuda.is_available():
322
- torch.cuda.empty_cache()
323
-
324
- def infer(base_model_name, prompt, negative_prompt, seed, batch_size, width, height, guidance_scale, num_inference_steps,
325
- sampler, schedule_type,
326
- civitai_api_key,
327
- zero_gpu_duration,
328
- *lora_data,
329
- progress=gr.Progress(track_tqdm=True)):
330
-
331
- duration = 60
332
- if zero_gpu_duration and int(zero_gpu_duration) > 0:
333
- duration = int(zero_gpu_duration)
334
-
335
- print(f"Using ZeroGPU duration: {duration} seconds")
336
-
337
- decorated_infer_logic = spaces.GPU(duration=duration)(_infer_logic)
338
-
339
- return decorated_infer_logic(
340
- base_model_name, prompt, negative_prompt, seed, batch_size, width, height, guidance_scale, num_inference_steps,
341
- sampler, schedule_type, civitai_api_key, *lora_data, progress=progress
342
- )
343
-
344
- def _parse_parameters(params_text):
345
- data = {'lora_ids': [''] * MAX_LORAS, 'lora_scales': [0.0] * MAX_LORAS}
346
- lines = params_text.strip().split('\n')
347
- data['prompt'] = lines[0]
348
- data['negative_prompt'] = lines[1].replace("Negative prompt:", "").strip() if len(lines) > 1 and lines[1].startswith("Negative prompt:") else ""
349
- params_line = lines[2] if len(lines) > 2 else ""
350
-
351
- def find_param(key, default, cast_type=str):
352
- match = re.search(fr"\b{key}: ([^,]+?)(,|$)", params_line)
353
- if match:
354
- try:
355
- return cast_type(match.group(1).strip())
356
- except (ValueError, TypeError):
357
- return default
358
- return default
359
-
360
- data['steps'] = find_param("Steps", 28, int)
361
- data['sampler'] = find_param("Sampler", DEFAULT_SAMPLER)
362
- data['schedule_type'] = find_param("Schedule type", DEFAULT_SCHEDULE_TYPE)
363
- data['cfg_scale'] = find_param("CFG scale", 7.0, float)
364
- data['seed'] = find_param("Seed", -1, int)
365
- data['base_model'] = find_param("Base Model", MODEL_LIST[0])
366
- data['model_hash'] = find_param("Model hash", None)
367
-
368
- size_match = re.search(r"Size: (\d+)x(\d+)", params_line); data['width'], data['height'] = (int(size_match.group(1)), int(size_match.group(2))) if size_match else (1024, 1024)
369
- if loras_match := re.search(r"LoRAs: \[(.+?)\]", params_line):
370
- for i, (lora_id, lora_scale) in enumerate(re.findall(r"ID: (\d+), Weight: ([\d.]+)", loras_match.group(1))):
371
- if i < MAX_LORAS: data['lora_ids'][i] = lora_id; data['lora_scales'][i] = float(lora_scale)
372
- return data
373
-
374
- def get_png_info(image):
375
- if image is None: return "", "", "Please upload an image first."
376
- params = image.info.get('parameters', None)
377
- if not params: return "", "", "No metadata found in the image."
378
- try:
379
- parsed_data = _parse_parameters(params)
380
- lines = params.strip().split('\n')
381
- other_params_text = lines[2] if len(lines) > 2 else ""
382
- other_params_display = "\n".join([p.strip() for p in other_params_text.split(',')])
383
-
384
- return parsed_data.get('prompt', ''), parsed_data.get('negative_prompt', ''), other_params_display
385
- except Exception as e:
386
- return "", "", f"Error parsing metadata: {e}\n\nRaw metadata:\n{params}"
387
-
388
- def send_info_to_txt2img(image):
389
- if image is None or not (params := image.info.get('parameters', '')):
390
- return [gr.update()] * (12 + MAX_LORAS * 2 + 1)
391
-
392
- data = _parse_parameters(params)
393
-
394
- model_from_hash = HASH_TO_MODEL_MAP.get(data.get('model_hash'))
395
- final_base_model = model_from_hash if model_from_hash else data.get('base_model', MODEL_LIST[0])
396
-
397
- sampler_from_png = data.get('sampler', DEFAULT_SAMPLER)
398
- final_sampler = sampler_from_png if sampler_from_png in SAMPLER_MAP else DEFAULT_SAMPLER
399
-
400
- schedule_from_png = data.get('schedule_type', DEFAULT_SCHEDULE_TYPE)
401
- final_schedule_type = schedule_from_png if schedule_from_png in SCHEDULE_TYPE_MAP else DEFAULT_SCHEDULE_TYPE
402
-
403
- updates = [final_base_model, data['prompt'], data['negative_prompt'], data['seed'], gr.update(), gr.update(), data['width'], data['height'],
404
- data['cfg_scale'], data['steps'], final_sampler, final_schedule_type]
405
-
406
- for i in range(MAX_LORAS): updates.extend([data['lora_ids'][i], data['lora_scales'][i]])
407
- updates.append(gr.Tabs(selected=0))
408
- return updates
409
-
410
- # --- Execute model download on startup ---
411
- download_all_base_models_on_startup()
412
-
413
-
414
- with gr.Blocks(css="#col-container {margin: 0 auto; max-width: 1024px;}") as demo:
415
- gr.Markdown("# Animated SDXL T2I with LoRAs")
416
- with gr.Tabs(elem_id="tabs_container") as tabs:
417
- with gr.TabItem("txt2img", id=0):
418
- gr.Markdown("<div style='background-color: #282828; color: #a0aec0; padding: 10px; border-radius: 5px; margin-bottom: 15px;'>💡 <b>Tip:</b> Pre-downloading LoRAs before 'Run' can maximize ZeroGPU time.</div>")
419
- with gr.Column(elem_id="col-container"):
420
- with gr.Row():
421
- with gr.Column(scale=3):
422
- base_model_name = gr.Dropdown(label="Base Model", choices=MODEL_LIST, value="Laxhar/noobai-XL-Vpred-1.0")
423
- with gr.Column(scale=1):
424
- predownload_lora_button = gr.Button("Pre-download LoRAs")
425
- run_button = gr.Button("Run", variant="primary")
426
-
427
- predownload_status = gr.Markdown("")
428
- prompt = gr.Text(label="Prompt", lines=3, placeholder="Enter your prompt")
429
- negative_prompt = gr.Text(label="Negative prompt", lines=3, placeholder="Enter a negative prompt", value=DEFAULT_NEGATIVE_PROMPT)
430
-
431
- # --- UI Layout ---
432
- with gr.Row():
433
- with gr.Column(scale=2):
434
- with gr.Row():
435
- width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
436
- height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
437
- with gr.Row():
438
- sampler = gr.Dropdown(label="Sampling method", choices=list(SAMPLER_MAP.keys()), value=DEFAULT_SAMPLER)
439
- schedule_type = gr.Dropdown(label="Schedule type", choices=SCHEDULE_TYPE_MAP, value=DEFAULT_SCHEDULE_TYPE)
440
- with gr.Row():
441
- guidance_scale = gr.Slider(label="CFG Scale", minimum=0.0, maximum=20.0, step=0.1, value=7)
442
- num_inference_steps = gr.Slider(label="Sampling steps", minimum=1, maximum=50, step=1, value=28)
443
-
444
- with gr.Column(scale=1):
445
- result = gr.Gallery(label="Result", show_label=False, elem_id="result_gallery", columns=2, object_fit="contain", height="auto")
446
-
447
- with gr.Row():
448
- seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
449
- batch_size = gr.Slider(label="Batch size", minimum=1, maximum=8, step=1, value=1)
450
- zero_gpu_duration = gr.Number(
451
- label="ZeroGPU Duration (s)",
452
- value=None,
453
- placeholder="Default: 60s",
454
- info="Optional: Leave empty for default (60s), max to 120"
455
- )
456
-
457
- with gr.Accordion("LoRA Settings", open=False):
458
- gr.Markdown("⚠️ **Responsible Use Notice:** Please avoid excessive, rapid, or automated (scripted) use of the pre-download LoRA feature. Overt misuse may lead to service disruption. Thank you for your cooperation.")
459
- civitai_api_key = gr.Textbox(label="Optional Civitai API Key", info="Get from your Civitai account settings...", placeholder="Enter your Civitai API Key here", type="password", show_label=True)
460
- gr.Markdown("Find the Model Version ID in the LoRA page URL (e.g., `modelVersionId=12345`) and fill it in below.")
461
- lora_rows, lora_civitai_id_inputs, lora_scale_inputs = [], [], []
462
- for i in range(MAX_LORAS):
463
- with gr.Row(visible=(i == 0)) as row:
464
- lora_civitai_id = gr.Textbox(label=f"LoRA {i+1} - Civitai Model Version ID", placeholder="e.g.: 1834914")
465
- lora_scale = gr.Slider(label=f"Weight {i+1}", minimum=0.0, maximum=2.0, step=0.05, value=0.0)
466
- lora_rows.append(row); lora_civitai_id_inputs.append(lora_civitai_id); lora_scale_inputs.append(lora_scale)
467
- with gr.Row():
468
- add_lora_button = gr.Button("✚ Add LoRA", variant="secondary")
469
- lora_count_state = gr.State(value=1)
470
- all_lora_inputs = [item for pair in zip(lora_civitai_id_inputs, lora_scale_inputs) for item in pair]
471
-
472
- with gr.TabItem("PNG Info", id=1):
473
- with gr.Column(elem_id="col-container"):
474
- gr.Markdown("Upload a generated image to view its generation data.")
475
- info_image_input = gr.Image(type="pil", label="Upload Image")
476
- with gr.Row():
477
- info_get_button = gr.Button("Get Info", variant="secondary")
478
- send_to_txt2img_button = gr.Button("Send to txt2img", variant="primary")
479
- gr.Markdown("### Positive Prompt"); info_prompt_output = gr.Textbox(lines=3, interactive=False, show_label=False)
480
- gr.Markdown("### Negative Prompt"); info_neg_prompt_output = gr.Textbox(lines=3, interactive=False, show_label=False)
481
- gr.Markdown("### Other Parameters"); info_params_output = gr.Textbox(lines=5, interactive=False, show_label=False)
482
-
483
- gr.Markdown("<div style='text-align: center; margin-top: 20px;'>Made by <a href='https://civitai.com/user/RioShiina'>RioShiina</a> with ❤️</div>")
484
-
485
- # --- Event Handlers ---
486
-
487
- def add_lora_row(current_count):
488
- current_count = int(current_count)
489
- if current_count < MAX_LORAS:
490
- updates = {lora_count_state: current_count + 1, lora_rows[current_count]: gr.Row(visible=True)}
491
- if current_count + 1 == MAX_LORAS: updates[add_lora_button] = gr.Button(visible=False)
492
- return updates
493
- return {lora_count_state: current_count}
494
-
495
- def start_lora_predownload():
496
- """This function provides immediate feedback to the user."""
497
- return "⏳ Downloading... please wait. This may take a moment."
498
-
499
- # --- Chain events for immediate feedback ---
500
- predownload_click_event = predownload_lora_button.click(
501
- fn=start_lora_predownload,
502
- inputs=None,
503
- outputs=[predownload_status],
504
- queue=False
505
- ).then(
506
- fn=pre_download_loras,
507
- inputs=[civitai_api_key, *all_lora_inputs],
508
- outputs=[predownload_status]
509
- )
510
-
511
- add_lora_button.click(fn=add_lora_row, inputs=[lora_count_state], outputs=[lora_count_state, add_lora_button] + lora_rows)
512
-
513
- run_button.click(fn=infer,
514
- inputs=[base_model_name, prompt, negative_prompt, seed, batch_size, width, height, guidance_scale, num_inference_steps, sampler, schedule_type, civitai_api_key, zero_gpu_duration, *all_lora_inputs],
515
- outputs=[result])
516
-
517
- info_get_button.click(fn=get_png_info, inputs=[info_image_input], outputs=[info_prompt_output, info_neg_prompt_output, info_params_output])
518
-
519
- txt2img_outputs = [base_model_name, prompt, negative_prompt, seed, batch_size, zero_gpu_duration, width, height, guidance_scale, num_inference_steps, sampler, schedule_type, *all_lora_inputs, tabs]
520
- send_to_txt2img_button.click(fn=send_info_to_txt2img, inputs=[info_image_input], outputs=txt2img_outputs)
521
-
522
-
1
+ import spaces
2
+ import gradio as gr
3
+ import numpy as np
4
+ import PIL.Image
5
+ from PIL import Image, PngImagePlugin
6
+ import random
7
+ from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler, DDIMScheduler, UniPCMultistepScheduler, HeunDiscreteScheduler, LMSDiscreteScheduler
8
+ import torch
9
+ from compel import Compel, ReturnedEmbeddingsType
10
+ import requests
11
+ import os
12
+ import re
13
+ import gc
14
+ import hashlib
15
+ from huggingface_hub import hf_hub_download, snapshot_download
16
+ import time
17
+
18
+ # This dummy function is required to pass the Hugging Face Spaces startup check for GPU apps.
19
+ @spaces.GPU(duration=60)
20
+ def dummy_gpu_for_startup():
21
+ print("Dummy function for startup check executed. This is normal.")
22
+ return "Startup check passed."
23
+
24
+ # --- Constants ---
25
+ MAX_LORAS = 5
26
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
+ MAX_SEED = np.iinfo(np.int64).max
28
+ MAX_IMAGE_SIZE = 1216
29
+ SAMPLER_MAP = {
30
+ "Euler a": EulerAncestralDiscreteScheduler,
31
+ "Euler": EulerDiscreteScheduler,
32
+ "DPM++ 2M Karras": DPMSolverMultistepScheduler,
33
+ "DDIM": DDIMScheduler,
34
+ "UniPC": UniPCMultistepScheduler,
35
+ "Heun": HeunDiscreteScheduler,
36
+ "LMS": LMSDiscreteScheduler,
37
+ }
38
+ SCHEDULE_TYPE_MAP = ["Default", "Karras", "Uniform", "SGM Uniform"]
39
+ LORA_SOURCE_CHOICES = ["Civitai", "TensorArt", "Custom URL", "File"]
40
+ DEFAULT_SCHEDULE_TYPE = "Default"
41
+ DEFAULT_SAMPLER = "Euler a"
42
+ DEFAULT_NEGATIVE_PROMPT = "monochrome, (low quality, worst quality:1.2), 3d, watermark, signature, ugly, poorly drawn,"
43
+ DOWNLOAD_DIR = "/tmp/loras"
44
+ os.makedirs(DOWNLOAD_DIR, exist_ok=True)
45
+
46
+ # --- Model Lists ---
47
+ MODEL_LIST = [
48
+ "dhead/wai-nsfw-illustrious-sdxl-v140-sdxl",
49
+ "Laxhar/noobai-XL-Vpred-1.0",
50
+ "John6666/hassaku-xl-illustrious-v30-sdxl",
51
+ "RedRayz/hikari_noob_v-pred_1.2.2",
52
+ "bluepen5805/noob_v_pencil-XL",
53
+ "Laxhar/noobai-XL-1.1"
54
+ ]
55
+
56
+ # --- Model Display Name Mapping ---
57
+ MODEL_DISPLAY_NAME_MAP = {
58
+ "dhead/wai-nsfw-illustrious-sdxl-v140-sdxl": "WAI0731/wai-nsfw-illustrious-sdxl-v140-sdxl",
59
+ "Laxhar/noobai-XL-Vpred-1.0": "Laxhar/noobai-XL-Vpred-1.0",
60
+ "John6666/hassaku-xl-illustrious-v30-sdxl": "Ikena/hassaku-xl-illustrious-v30-sdxl",
61
+ "RedRayz/hikari_noob_v-pred_1.2.2": "RedRayz/hikari_noob_v-pred_1.2.2",
62
+ "bluepen5805/noob_v_pencil-XL": "bluepen5805/noob_v_pencil-XL",
63
+ "Laxhar/noobai-XL-1.1": "Laxhar/noobai-XL-1.1"
64
+ }
65
+ DISPLAY_NAME_TO_BACKEND_MAP = {v: k for k, v in MODEL_DISPLAY_NAME_MAP.items()}
66
+
67
+ # --- List of V-Prediction Models ---
68
+ V_PREDICTION_MODELS = [
69
+ "Laxhar/noobai-XL-Vpred-1.0",
70
+ "RedRayz/hikari_noob_v-pred_1.2.2",
71
+ "bluepen5805/noob_v_pencil-XL"
72
+ ]
73
+
74
+ # --- Dictionary for single-file models now stores the filename ---
75
+ SINGLE_FILE_MODELS = {
76
+ "bluepen5805/noob_v_pencil-XL": "noob_v_pencil-XL-v3.0.0.safetensors"
77
+ }
78
+
79
+ # --- Model Hash to Name Mapping ---
80
+ HASH_TO_MODEL_MAP = {
81
+ "bdb59bac77": "dhead/wai-nsfw-illustrious-sdxl-v140-sdxl",
82
+ "ea349eeae8": "Laxhar/noobai-XL-Vpred-1.0",
83
+ "b4fb5f829a": "John6666/hassaku-xl-illustrious-v30-sdxl",
84
+ "6681e8e4b1": "Laxhar/noobai-XL-1.1",
85
+ "90b7911a78": "bluepen5805/noob_v_pencil-XL",
86
+ "874170688a": "RedRayz/hikari_noob_v-pred_1.2.2"
87
+ }
88
+ MODEL_TO_HASH_MAP = {v: k for k, v in HASH_TO_MODEL_MAP.items()}
89
+
90
+
91
+ def download_all_base_models_on_startup():
92
+ """Downloads all base models listed in MODEL_LIST when the app starts."""
93
+ print("--- Starting pre-download of all base models ---")
94
+ for model_name in MODEL_LIST:
95
+ try:
96
+ print(f"Downloading: {model_name}...")
97
+ start_time = time.time()
98
+ if model_name in SINGLE_FILE_MODELS:
99
+ filename = SINGLE_FILE_MODELS[model_name]
100
+ hf_hub_download(repo_id=model_name, filename=filename)
101
+ else:
102
+ snapshot_download(repo_id=model_name, ignore_patterns=["*.onnx", "*.flax"])
103
+ end_time = time.time()
104
+ print(f"✅ Successfully downloaded {model_name} in {end_time - start_time:.2f} seconds.")
105
+ except Exception as e:
106
+ print(f"❌ Failed to download {model_name}: {e}")
107
+ finally:
108
+ gc.collect()
109
+ if torch.cuda.is_available():
110
+ torch.cuda.empty_cache()
111
+ print("--- Finished pre-downloading all base models ---")
112
+
113
+ def get_civitai_file_info(version_id):
114
+ """Gets the file metadata for a model version via the Civitai API."""
115
+ api_url = f"https://civitai.com/api/v1/model-versions/{version_id}"
116
+ try:
117
+ response = requests.get(api_url, timeout=10)
118
+ response.raise_for_status()
119
+ data = response.json()
120
+ for file_data in data.get('files', []):
121
+ if file_data.get('type') == 'Model' and file_data['name'].endswith('.safetensors'):
122
+ return file_data
123
+ if data.get('files'):
124
+ return data['files'][0]
125
+ return None
126
+ except Exception as e:
127
+ print(f"Could not get file info from Civitai API: {e}")
128
+ return None
129
+
130
+ def get_tensorart_file_info(model_id):
131
+ """Gets the file metadata for a model via the TensorArt API."""
132
+ api_url = f"https://tensor.art/api/v1/models/{model_id}"
133
+ try:
134
+ response = requests.get(api_url, timeout=10)
135
+ response.raise_for_status()
136
+ data = response.json()
137
+ model_versions = data.get('modelVersions', [])
138
+ if not model_versions: return None
139
+ for file_data in model_versions[0].get('files', []):
140
+ if file_data['name'].endswith('.safetensors'):
141
+ return file_data
142
+ return model_versions[0]['files'][0] if model_versions[0].get('files') else None
143
+ except Exception as e:
144
+ print(f"Could not get file info from TensorArt API: {e}")
145
+ return None
146
+
147
+ def download_file(url, save_path, api_key=None, progress=None, desc=""):
148
+ """Downloads a file, skipping if it already exists."""
149
+ if os.path.exists(save_path):
150
+ return f"File already exists: {os.path.basename(save_path)}"
151
+
152
+ headers = {}
153
+ if api_key and api_key.strip():
154
+ headers['Authorization'] = f'Bearer {api_key}'
155
+
156
+ try:
157
+ if progress: progress(0, desc=desc)
158
+ response = requests.get(url, stream=True, headers=headers, timeout=15)
159
+ response.raise_for_status()
160
+ total_size = int(response.headers.get('content-length', 0))
161
+
162
+ with open(save_path, "wb") as f:
163
+ downloaded = 0
164
+ for chunk in response.iter_content(chunk_size=8192):
165
+ f.write(chunk)
166
+ if progress and total_size > 0:
167
+ downloaded += len(chunk)
168
+ progress(downloaded / total_size, desc=desc)
169
+ return f"Successfully downloaded: {os.path.basename(save_path)}"
170
+ except Exception as e:
171
+ if os.path.exists(save_path): os.remove(save_path)
172
+ return f"Download failed for {os.path.basename(save_path)}: {e}"
173
+
174
+ def get_lora_path(source, id_or_url, civitai_key, tensorart_key, progress):
175
+ """Determines the local path for a LoRA, downloading it if necessary."""
176
+ if not id_or_url or not id_or_url.strip(): return None, "No ID/URL provided."
177
+
178
+ if source == "Civitai":
179
+ version_id = id_or_url.strip()
180
+ local_path = os.path.join(DOWNLOAD_DIR, f"civitai_{version_id}.safetensors")
181
+ if os.path.exists(local_path): return local_path, "File already exists."
182
+ file_info = get_civitai_file_info(version_id)
183
+ api_key_to_use = civitai_key
184
+ source_name = f"Civitai ID {version_id}"
185
+ elif source == "TensorArt":
186
+ model_id = id_or_url.strip()
187
+ local_path = os.path.join(DOWNLOAD_DIR, f"tensorart_{model_id}.safetensors")
188
+ if os.path.exists(local_path): return local_path, "File already exists."
189
+ file_info = get_tensorart_file_info(model_id)
190
+ api_key_to_use = tensorart_key
191
+ source_name = f"TensorArt ID {model_id}"
192
+ elif source == "Custom URL":
193
+ url = id_or_url.strip()
194
+ url_hash = hashlib.md5(url.encode()).hexdigest()
195
+ local_path = os.path.join(DOWNLOAD_DIR, f"custom_{url_hash}.safetensors")
196
+ if os.path.exists(local_path): return local_path, "File already exists."
197
+ file_info = {'downloadUrl': url}
198
+ api_key_to_use = None
199
+ source_name = f"URL {url[:30]}..."
200
+ else:
201
+ return None, "Invalid source."
202
+
203
+ if not file_info: return None, f"Could not get file info for {source_name}."
204
+ download_url = file_info.get('downloadUrl')
205
+ if not download_url: return None, f"Could not get download link for {source_name}."
206
+
207
+ status = download_file(download_url, local_path, api_key=api_key_to_use, progress=progress, desc=f"Downloading {source_name}")
208
+ if "Successfully" in status:
209
+ return local_path, status
210
+ return None, status
211
+
212
+
213
+ def pre_download_loras(civitai_api_key, tensorart_api_key, *lora_data, progress=gr.Progress(track_tqdm=True)):
214
+ sources, ids, scales, files = lora_data[0::4], lora_data[1::4], lora_data[2::4], lora_data[3::4]
215
+ status_log = []
216
+
217
+ active_loras_to_download = [
218
+ (src, lora_id) for src, lora_id, scale, f in zip(sources, ids, scales, files)
219
+ if src in ["Civitai", "TensorArt", "Custom URL"] and lora_id and lora_id.strip() and f is None
220
+ ]
221
+
222
+ if not active_loras_to_download:
223
+ return "No remote LoRAs specified for pre-downloading."
224
+
225
+ for i, (source, lora_id) in enumerate(active_loras_to_download):
226
+ progress(i / len(active_loras_to_download), desc=f"Processing {source} ID: {lora_id}")
227
+ _, status = get_lora_path(source, lora_id, civitai_api_key, tensorart_api_key, progress)
228
+ status_log.append(f"* {source} ID {lora_id}: {status}")
229
+
230
+ return "\n".join(status_log)
231
+
232
+
233
+ def process_long_prompt(compel_proc, prompt, negative_prompt=""):
234
+ """Uses Compel to process prompts that may be too long for the standard tokenizer."""
235
+ try:
236
+ conditioning, pooled = compel_proc([prompt, negative_prompt])
237
+ return conditioning, pooled
238
+ except Exception:
239
+ return None, None
240
+
241
+
242
+ def _infer_logic(base_model_name, prompt, negative_prompt, seed, batch_size, width, height, guidance_scale, num_inference_steps,
243
+ sampler, schedule_type,
244
+ civitai_api_key, tensorart_api_key,
245
+ *lora_data,
246
+ progress=gr.Progress(track_tqdm=True)):
247
+
248
+ pipe = None
249
+ try:
250
+ progress(0, desc=f"Loading model: {base_model_name}")
251
+
252
+ if base_model_name in SINGLE_FILE_MODELS:
253
+ filename = SINGLE_FILE_MODELS[base_model_name]
254
+ local_path = hf_hub_download(repo_id=base_model_name, filename=filename)
255
+ pipe = StableDiffusionXLPipeline.from_single_file(local_path, torch_dtype=torch.float16, use_safetensors=True)
256
+ else:
257
+ pipe = StableDiffusionXLPipeline.from_pretrained(base_model_name, torch_dtype=torch.float16, use_safetensors=True)
258
+ pipe.to(device)
259
+
260
+ batch_size = int(batch_size)
261
+ seed = int(seed)
262
+ pipe.unload_lora_weights()
263
+
264
+ scheduler_class = SAMPLER_MAP.get(sampler, EulerAncestralDiscreteScheduler)
265
+ scheduler_config = pipe.scheduler.config
266
+
267
+ if base_model_name in V_PREDICTION_MODELS: scheduler_config['prediction_type'] = 'v_prediction'
268
+ else: scheduler_config['prediction_type'] = 'epsilon'
269
+
270
+ scheduler_kwargs = {}
271
+ if schedule_type == "Karras" or (schedule_type == "Default" and sampler == "DPM++ 2M Karras"):
272
+ scheduler_kwargs['use_karras_sigmas'] = True
273
+ elif schedule_type == "Uniform": scheduler_kwargs['use_karras_sigmas'] = False
274
+ elif schedule_type == "SGM Uniform": scheduler_kwargs['algorithm_type'] = 'sgm_uniform'
275
+ pipe.scheduler = scheduler_class.from_config(scheduler_config, **scheduler_kwargs)
276
+
277
+ compel = Compel(tokenizer=[pipe.tokenizer, pipe.tokenizer_2], text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
278
+ returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
279
+ requires_pooled=[False, True], truncate_long_prompts=False)
280
+
281
+ sources, ids, scales, files = lora_data[0::4], lora_data[1::4], lora_data[2::4], lora_data[3::4]
282
+ active_loras, active_lora_names_for_meta = [], []
283
+
284
+ for i, (source, lora_id, scale, custom_file) in enumerate(zip(sources, ids, scales, files)):
285
+ if scale > 0:
286
+ local_lora_path = None
287
+ lora_name_for_meta = "Unknown LoRA"
288
+
289
+ if custom_file is not None:
290
+ local_lora_path = custom_file.name
291
+ lora_name_for_meta = f"Custom LoRA ({os.path.basename(local_lora_path)}, Weight: {scale})"
292
+ elif lora_id and lora_id.strip():
293
+ progress(0.05 + (i * 0.05), desc=f"Handling LoRA {i+1} ({source})")
294
+ local_lora_path, _ = get_lora_path(source, lora_id, civitai_api_key, tensorart_api_key, progress)
295
+ lora_name_for_meta = f"{source} LoRA (ID: {lora_id}, Weight: {scale})"
296
+
297
+ if local_lora_path and os.path.exists(local_lora_path):
298
+ adapter_name = f"lora_{i+1}"
299
+ pipe.load_lora_weights(local_lora_path, adapter_name=adapter_name)
300
+ active_loras.append((adapter_name, scale))
301
+ active_lora_names_for_meta.append(lora_name_for_meta)
302
+ else:
303
+ print(f"Skipping LoRA {i+1} as file could not be found or downloaded.")
304
+
305
+ if active_loras:
306
+ adapter_names, adapter_weights = zip(*active_loras)
307
+ pipe.set_adapters(list(adapter_names), list(adapter_weights))
308
+
309
+ conditioning, pooled = process_long_prompt(compel, prompt, negative_prompt)
310
+
311
+ pipe_args = {"guidance_scale": guidance_scale, "num_inference_steps": num_inference_steps, "width": width, "height": height}
312
+ output_images = []
313
+ loras_string = f"LoRAs: [{', '.join(active_lora_names_for_meta)}]" if active_lora_names_for_meta else ""
314
+
315
+ for i in range(batch_size):
316
+ progress(i / batch_size, desc=f"Generating image {i+1}/{batch_size}")
317
+ current_seed = seed if i == 0 and seed != -1 else random.randint(0, MAX_SEED)
318
+ generator = torch.Generator(device=device).manual_seed(current_seed)
319
+ pipe_args["generator"] = generator
320
+
321
+ if conditioning is not None:
322
+ image = pipe(prompt_embeds=conditioning[0:1], pooled_prompt_embeds=pooled[0:1], negative_prompt_embeds=conditioning[1:2], negative_pooled_prompt_embeds=pooled[1:2], **pipe_args).images[0]
323
+ else:
324
+ image = pipe(prompt=prompt, negative_prompt=negative_prompt, **pipe_args).images[0]
325
+
326
+ model_hash = MODEL_TO_HASH_MAP.get(base_model_name, "N/A")
327
+ params_string = f"{prompt}\nNegative prompt: {negative_prompt}\n"
328
+ params_string += f"Steps: {num_inference_steps}, Sampler: {sampler}, Schedule type: {schedule_type}, CFG scale: {guidance_scale}, Seed: {current_seed}, Size: {width}x{height}, Base Model: {base_model_name}, Model hash: {model_hash}, {loras_string}".strip()
329
+ image.info = {'parameters': params_string}
330
+ output_images.append(image)
331
+
332
+ return output_images
333
+
334
+ except Exception as e:
335
+ print(f"An error occurred during generation: {e}")
336
+ error_str = str(e).lower()
337
+ if "dora_scale" in error_str and "not compatible in diffusers" in error_str:
338
+ raise gr.Error("This LoRA appears to be a DoRA model. Diffusers currently has limited support for this format, which may cause errors.")
339
+ raise gr.Error(f"Generation failed: {e}")
340
+ finally:
341
+ if pipe is not None:
342
+ pipe.disable_lora()
343
+ del pipe
344
+ gc.collect()
345
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
346
+
347
+ def infer(base_model_display_name, prompt, negative_prompt, seed, batch_size, width, height, guidance_scale, num_inference_steps,
348
+ sampler, schedule_type, civitai_api_key, tensorart_api_key, zero_gpu_duration, *lora_data,
349
+ progress=gr.Progress(track_tqdm=True)):
350
+
351
+ base_model_name = DISPLAY_NAME_TO_BACKEND_MAP.get(base_model_display_name, base_model_display_name)
352
+ duration = 60
353
+ if zero_gpu_duration and int(zero_gpu_duration) > 0: duration = int(zero_gpu_duration)
354
+ print(f"Using ZeroGPU duration: {duration} seconds")
355
+
356
+ decorated_infer_logic = spaces.GPU(duration=duration)(_infer_logic)
357
+
358
+ return decorated_infer_logic(
359
+ base_model_name, prompt, negative_prompt, seed, batch_size, width, height, guidance_scale, num_inference_steps,
360
+ sampler, schedule_type, civitai_api_key, tensorart_api_key, *lora_data, progress=progress
361
+ )
362
+
363
+ def _parse_parameters(params_text):
364
+ data = {'lora_ids': [''] * MAX_LORAS, 'lora_scales': [0.0] * MAX_LORAS}
365
+ lines = params_text.strip().split('\n')
366
+ data['prompt'] = lines[0]
367
+ data['negative_prompt'] = lines[1].replace("Negative prompt:", "").strip() if len(lines) > 1 and lines[1].startswith("Negative prompt:") else ""
368
+ params_line = lines[2] if len(lines) > 2 else ""
369
+
370
+ def find_param(key, default, cast_type=str):
371
+ match = re.search(fr"\b{key}: ([^,]+?)(,|$)", params_line)
372
+ return cast_type(match.group(1).strip()) if match else default
373
+
374
+ data['steps'] = find_param("Steps", 28, int)
375
+ data['sampler'] = find_param("Sampler", DEFAULT_SAMPLER, str)
376
+ data['schedule_type'] = find_param("Schedule type", DEFAULT_SCHEDULE_TYPE, str)
377
+ data['cfg_scale'] = find_param("CFG scale", 7.0, float)
378
+ data['seed'] = find_param("Seed", -1, int)
379
+ data['base_model'] = find_param("Base Model", MODEL_LIST[0], str)
380
+ data['model_hash'] = find_param("Model hash", None, str)
381
+
382
+ size_match = re.search(r"Size: (\d+)x(\d+)", params_line); data['width'], data['height'] = (int(size_match.group(1)), int(size_match.group(2))) if size_match else (1024, 1024)
383
+ return data
384
+
385
+ def get_png_info(image):
386
+ if image is None: return "", "", "Please upload an image first."
387
+ params = image.info.get('parameters', None)
388
+ if not params: return "", "", "No metadata found in the image."
389
+ try:
390
+ parsed_data = _parse_parameters(params)
391
+ lines = params.strip().split('\n')
392
+ other_params_text = lines[2] if len(lines) > 2 else ""
393
+ other_params_display = "\n".join([p.strip() for p in other_params_text.split(',')])
394
+ return parsed_data.get('prompt', ''), parsed_data.get('negative_prompt', ''), other_params_display
395
+ except Exception as e:
396
+ return "", "", f"Error parsing metadata: {e}\n\nRaw metadata:\n{params}"
397
+
398
+ def send_info_to_txt2img(image):
399
+ if image is None or not (params := image.info.get('parameters', '')):
400
+ num_lora_params = MAX_LORAS * 4
401
+ num_other_params = 12
402
+ num_api_keys = 2
403
+ return [gr.update()] * (num_other_params + num_api_keys + num_lora_params + 1)
404
+
405
+ data = _parse_parameters(params)
406
+
407
+ model_from_hash = HASH_TO_MODEL_MAP.get(data.get('model_hash'))
408
+ backend_base_model = model_from_hash if model_from_hash else data.get('base_model', MODEL_LIST[0])
409
+
410
+ final_display_model = MODEL_DISPLAY_NAME_MAP.get(backend_base_model, backend_base_model)
411
+ final_sampler = data.get('sampler', DEFAULT_SAMPLER)
412
+
413
+ schedule_from_png = data.get('schedule_type', DEFAULT_SCHEDULE_TYPE)
414
+ final_schedule_type = schedule_from_png if schedule_from_png in SCHEDULE_TYPE_MAP else DEFAULT_SCHEDULE_TYPE
415
+
416
+ updates = [final_display_model, data['prompt'], data['negative_prompt'], data['seed'], gr.update(), gr.update(), data['width'], data['height'],
417
+ data['cfg_scale'], data['steps'], final_sampler, final_schedule_type, gr.update(), gr.update()]
418
+
419
+ for i in range(MAX_LORAS):
420
+ updates.extend([gr.update(), gr.update(), gr.update(), gr.update()])
421
+ updates.append(gr.Tabs(selected=0))
422
+ return updates
423
+
424
+ # --- Execute model download on startup ---
425
+ download_all_base_models_on_startup()
426
+
427
+
428
+ with gr.Blocks(css="#col-container {margin: 0 auto; max-width: 1024px;}") as demo:
429
+ gr.Markdown("# Animated SDXL T2I with LoRAs")
430
+ with gr.Tabs(elem_id="tabs_container") as tabs:
431
+ with gr.TabItem("txt2img", id=0):
432
+ gr.Markdown("<div style='background-color: #282828; color: #a0aec0; padding: 10px; border-radius: 5px; margin-bottom: 15px;'>💡 <b>Tip:</b> Pre-downloading LoRAs before 'Run' can maximize ZeroGPU time.</div>")
433
+ with gr.Column(elem_id="col-container"):
434
+ with gr.Row():
435
+ with gr.Column(scale=3):
436
+ default_backend_model = "Laxhar/noobai-XL-Vpred-1.0"
437
+ default_display_name = MODEL_DISPLAY_NAME_MAP.get(default_backend_model, default_backend_model)
438
+ base_model_name_input = gr.Dropdown(label="Base Model", choices=list(MODEL_DISPLAY_NAME_MAP.values()), value=default_display_name)
439
+ with gr.Column(scale=1):
440
+ predownload_lora_button = gr.Button("Pre-download LoRAs")
441
+ run_button = gr.Button("Run", variant="primary")
442
+
443
+ predownload_status = gr.Markdown("")
444
+ prompt = gr.Text(label="Prompt", lines=3, placeholder="Enter your prompt")
445
+ negative_prompt = gr.Text(label="Negative prompt", lines=3, placeholder="Enter a negative prompt", value=DEFAULT_NEGATIVE_PROMPT)
446
+
447
+ with gr.Row():
448
+ with gr.Column(scale=2):
449
+ with gr.Row():
450
+ width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
451
+ height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
452
+ with gr.Row():
453
+ sampler = gr.Dropdown(label="Sampling method", choices=list(SAMPLER_MAP.keys()), value=DEFAULT_SAMPLER)
454
+ schedule_type = gr.Dropdown(label="Schedule type", choices=SCHEDULE_TYPE_MAP, value=DEFAULT_SCHEDULE_TYPE)
455
+ with gr.Row():
456
+ guidance_scale = gr.Slider(label="CFG Scale", minimum=0.0, maximum=20.0, step=0.1, value=7)
457
+ num_inference_steps = gr.Slider(label="Sampling steps", minimum=1, maximum=50, step=1, value=28)
458
+ with gr.Column(scale=1):
459
+ result = gr.Gallery(label="Result", show_label=False, elem_id="result_gallery", columns=2, object_fit="contain", height="auto")
460
+
461
+ with gr.Row():
462
+ seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
463
+ batch_size = gr.Slider(label="Batch size", minimum=1, maximum=8, step=1, value=1)
464
+ zero_gpu_duration = gr.Number(label="ZeroGPU Duration (s)", value=None, placeholder="Default: 60s", info="Optional: Leave empty for default (60s), maximum 120")
465
+
466
+ with gr.Accordion("LoRA Settings", open=False):
467
+ gr.Markdown("⚠️ **Responsible Use Notice:** Please avoid excessive, rapid, or automated (scripted) use of the pre-download LoRA feature. Overt misuse may lead to service disruption. Thank you for your cooperation.")
468
+
469
+ gr.Markdown("For LoRAs that require login to download, you may need to enter the corresponding API Key.")
470
+ with gr.Row():
471
+ with gr.Column(scale=1):
472
+ gr.Markdown("**Civitai API Key**")
473
+ civitai_api_key = gr.Textbox(show_label=False, placeholder="Enter your Civitai API Key here", type="password", container=False)
474
+ with gr.Column(scale=1):
475
+ gr.Markdown("**TensorArt API Key**")
476
+ tensorart_api_key = gr.Textbox(show_label=False, placeholder="Enter your TensorArt API Key here", type="password", container=False)
477
+
478
+ gr.Markdown("---")
479
+ gr.Markdown("For each LoRA, choose a source, provide an ID/URL, or upload a file.")
480
+
481
+ gr.Markdown("""
482
+ <div style='background-color: #282828; color: #a0aec0; padding: 10px; border-radius: 5px; margin-top: 10px; margin-bottom: 15px;'>
483
+ <b>Input Examples:</b>
484
+ <ul>
485
+ <li><b>Civitai:</b> Enter the <b>Model Version ID</b>, not the Model ID. Example: <code>133755</code> (Found in the URL, e.g., <code>civitai.com/models/122136?modelVersionId=<b>133755</b></code>)</li>
486
+ <li><b>TensorArt:</b> Enter the <b>Model ID</b>. Example: <code>706684852832599558</code> (Found in the URL, e.g., <code>tensor.art/models/<b>706684852832599558</b></code>)</li>
487
+ <li><b>Custom URL:</b> Provide a direct download link to a <code>.safetensors</code> file. Example: <code>https://huggingface.co/path/to/your/lora.safetensors</code></li>
488
+ <li><b>File:</b> Use the "Upload" button. The source will be set automatically.</li>
489
+ </ul>
490
+ </div>
491
+ """)
492
+
493
+ gr.Markdown("""
494
+ <div style='background-color: #282828; color: #a0aec0; padding: 10px; border-radius: 5px; margin-bottom: 15px;'>
495
+ <b>TODO:</b>
496
+ <ul style='margin-bottom: 0;'>
497
+ <li>When uploading a local LoRA, the page may not respond, but it is transferring. Please be patient. This issue is pending a fix.</li>
498
+ </ul>
499
+ </div>
500
+ """)
501
+
502
+ lora_rows = []
503
+ lora_source_inputs, lora_id_inputs, lora_scale_inputs, lora_upload_buttons = [], [], [], []
504
+
505
+ for i in range(MAX_LORAS):
506
+ with gr.Row(visible=(i == 0)) as row:
507
+ with gr.Column(scale=1, min_width=120):
508
+ lora_source = gr.Dropdown(label=f"LoRA {i+1} Source", choices=LORA_SOURCE_CHOICES, value="Civitai")
509
+ with gr.Column(scale=2, min_width=160):
510
+ lora_id = gr.Textbox(label="ID / URL / Uploaded File", placeholder="e.g.: 133755")
511
+ with gr.Column(scale=2, min_width=220):
512
+ lora_scale = gr.Slider(label="Weight", minimum=0.0, maximum=2.0, step=0.05, value=0.0)
513
+ with gr.Column(scale=1, min_width=80):
514
+ lora_upload = gr.UploadButton("Upload", file_types=[".safetensors"])
515
+
516
+ lora_rows.append(row)
517
+ lora_source_inputs.append(lora_source)
518
+ lora_id_inputs.append(lora_id)
519
+ lora_scale_inputs.append(lora_scale)
520
+ lora_upload_buttons.append(lora_upload)
521
+
522
+ lora_upload.upload(
523
+ fn=lambda f: (os.path.basename(f.name), "File") if f else (gr.update(), gr.update()),
524
+ inputs=[lora_upload],
525
+ outputs=[lora_id, lora_source]
526
+ )
527
+
528
+ with gr.Row():
529
+ add_lora_button = gr.Button("✚ Add LoRA", variant="secondary")
530
+ delete_lora_button = gr.Button("➖ Delete LoRA", variant="secondary", visible=False)
531
+
532
+ lora_count_state = gr.State(value=1)
533
+ all_lora_components_flat = [item for sublist in zip(lora_source_inputs, lora_id_inputs, lora_scale_inputs, lora_upload_buttons) for item in sublist]
534
+
535
+
536
+ with gr.TabItem("PNG Info", id=1):
537
+ with gr.Column(elem_id="col-container"):
538
+ gr.Markdown("Upload a generated image to view its generation data.")
539
+ info_image_input = gr.Image(type="pil", label="Upload Image")
540
+ with gr.Row():
541
+ info_get_button = gr.Button("Get Info", variant="secondary")
542
+ send_to_txt2img_button = gr.Button("Send to txt2img", variant="primary")
543
+ gr.Markdown("### Positive Prompt"); info_prompt_output = gr.Textbox(lines=3, interactive=False, show_label=False)
544
+ gr.Markdown("### Negative Prompt"); info_neg_prompt_output = gr.Textbox(lines=3, interactive=False, show_label=False)
545
+ gr.Markdown("### Other Parameters"); info_params_output = gr.Textbox(lines=5, interactive=False, show_label=False)
546
+
547
+ gr.Markdown("<div style='text-align: center; margin-top: 20px;'>Made by <a href='https://civitai.com/user/RioShiina'>RioShiina</a> with ❤️</div>")
548
+
549
+ # --- Event Handlers ---
550
+ def add_lora_row(current_count):
551
+ current_count = int(current_count)
552
+ if current_count < MAX_LORAS:
553
+ return {
554
+ lora_count_state: current_count + 1,
555
+ lora_rows[current_count]: gr.update(visible=True),
556
+ delete_lora_button: gr.update(visible=True),
557
+ add_lora_button: gr.update(visible=False) if (current_count + 1 == MAX_LORAS) else gr.update(visible=True)
558
+ }
559
+ return {}
560
+
561
+ def delete_lora_row(current_count):
562
+ current_count = int(current_count)
563
+ if current_count > 1:
564
+ row_index_to_hide = current_count - 1
565
+ return {
566
+ lora_count_state: current_count - 1,
567
+ lora_rows[row_index_to_hide]: gr.update(visible=False),
568
+ lora_id_inputs[row_index_to_hide]: gr.update(value=""),
569
+ lora_scale_inputs[row_index_to_hide]: gr.update(value=0.0),
570
+ add_lora_button: gr.update(visible=True),
571
+ delete_lora_button: gr.update(visible=False) if (current_count - 1 == 1) else gr.update(visible=True)
572
+ }
573
+ return {}
574
+
575
+ def start_lora_predownload():
576
+ return "⏳ Downloading... please wait. This may take a moment."
577
+
578
+ predownload_lora_button.click(
579
+ fn=start_lora_predownload,
580
+ inputs=None,
581
+ outputs=[predownload_status],
582
+ queue=False
583
+ ).then(
584
+ fn=pre_download_loras,
585
+ inputs=[civitai_api_key, tensorart_api_key, *all_lora_components_flat],
586
+ outputs=[predownload_status]
587
+ )
588
+
589
+ add_lora_button.click(
590
+ fn=add_lora_row,
591
+ inputs=[lora_count_state],
592
+ outputs=[lora_count_state, add_lora_button, delete_lora_button, *lora_rows]
593
+ )
594
+
595
+ delete_lora_button.click(
596
+ fn=delete_lora_row,
597
+ inputs=[lora_count_state],
598
+ outputs=[
599
+ lora_count_state,
600
+ add_lora_button,
601
+ delete_lora_button,
602
+ *lora_rows,
603
+ *lora_id_inputs,
604
+ *lora_scale_inputs
605
+ ]
606
+ )
607
+
608
+ run_button_inputs = [base_model_name_input, prompt, negative_prompt, seed, batch_size, width, height, guidance_scale, num_inference_steps, sampler, schedule_type, civitai_api_key, tensorart_api_key, zero_gpu_duration, *all_lora_components_flat]
609
+ run_button.click(fn=infer, inputs=run_button_inputs, outputs=[result])
610
+
611
+ info_get_button.click(fn=get_png_info, inputs=[info_image_input], outputs=[info_prompt_output, info_neg_prompt_output, info_params_output])
612
+
613
+ txt2img_outputs = [
614
+ base_model_name_input, prompt, negative_prompt, seed, batch_size,
615
+ zero_gpu_duration, width, height, guidance_scale, num_inference_steps,
616
+ sampler, schedule_type, civitai_api_key, tensorart_api_key,
617
+ *all_lora_components_flat, tabs
618
+ ]
619
+ send_to_txt2img_button.click(fn=send_info_to_txt2img, inputs=[info_image_input], outputs=txt2img_outputs)
620
+
621
  demo.queue().launch()