RioShiina committed
Commit 839b683 · verified · 1 Parent(s): 5fcd4f7

feat: Remove the "Pre-download Base Model" button and automatically download all base models on deployment.
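In outline, the change drops the per-click download callback and instead walks the base-model list once at import time, before the Gradio UI is built, so every checkpoint is already in the Hugging Face cache when the first generation runs. A minimal sketch of that pattern follows; `MODEL_LIST` and `SINGLE_FILE_MODELS` mirror the mappings defined elsewhere in app.py, and the values shown here are placeholders rather than the app's real entries.

# Sketch only: placeholder lists standing in for the ones defined in app.py.
from huggingface_hub import hf_hub_download, snapshot_download

MODEL_LIST = ["Laxhar/noobai-XL-Vpred-1.0"]   # placeholder repo ids
SINGLE_FILE_MODELS = {}                       # placeholder: repo_id -> checkpoint filename

def download_all_base_models_on_startup():
    """Fetch every base model into the local Hugging Face cache."""
    for model_name in MODEL_LIST:
        try:
            if model_name in SINGLE_FILE_MODELS:
                # Single-file checkpoints: download just the named file.
                hf_hub_download(repo_id=model_name, filename=SINGLE_FILE_MODELS[model_name])
            else:
                # Diffusers-format repos: download the whole snapshot, skipping unused weights.
                snapshot_download(repo_id=model_name, ignore_patterns=["*.onnx", "*.flax"])
        except Exception as e:
            print(f"Failed to download {model_name}: {e}")

# Runs once at module import, before gr.Blocks() constructs the UI.
download_all_base_models_on_startup()

The full version in the diff below additionally times each download and clears the CUDA cache between models.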

Files changed (1):
  app.py  (+36 -45)
app.py CHANGED
@@ -11,7 +11,8 @@ import requests
 import os
 import re
 import gc
-from huggingface_hub import hf_hub_download
+from huggingface_hub import hf_hub_download, snapshot_download
+import time
 
 # This dummy function is required to pass the Hugging Face Spaces startup check for GPU apps.
 @spaces.GPU(duration=60)
@@ -72,6 +73,32 @@ HASH_TO_MODEL_MAP = {
     "874170688a": "RedRayz/hikari_noob_v-pred_1.2.2"
 }
 
+def download_all_base_models_on_startup():
+    """Downloads all base models listed in MODEL_LIST when the app starts."""
+    print("--- Starting pre-download of all base models ---")
+    for model_name in MODEL_LIST:
+        try:
+            print(f"Downloading: {model_name}...")
+            start_time = time.time()
+            # Handle single-file models
+            if model_name in SINGLE_FILE_MODELS:
+                filename = SINGLE_FILE_MODELS[model_name]
+                hf_hub_download(repo_id=model_name, filename=filename)
+            # Handle standard diffusers models
+            else:
+                snapshot_download(repo_id=model_name, ignore_patterns=["*.onnx", "*.flax"])
+            end_time = time.time()
+            print(f"✅ Successfully downloaded {model_name} in {end_time - start_time:.2f} seconds.")
+        except Exception as e:
+            print(f"❌ Failed to download {model_name}: {e}")
+        finally:
+            # Clean up to conserve memory
+            gc.collect()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+    print("--- Finished pre-downloading all base models ---")
+
+
 def get_civitai_file_info(version_id):
     """Gets the file metadata for a model version via the Civitai API."""
     api_url = f"https://civitai.com/api/v1/model-versions/{version_id}"
@@ -125,42 +152,6 @@ def process_long_prompt(compel_proc, prompt, negative_prompt=""):
     except Exception:
         return None, None
 
-def pre_download_base_model(model_name, progress=gr.Progress(track_tqdm=True)):
-    if not model_name:
-        return "Please select a base model to download."
-
-    status_log = []
-    try:
-        progress(0, desc=f"Starting download for: {model_name}")
-
-        if model_name in SINGLE_FILE_MODELS:
-            filename = SINGLE_FILE_MODELS[model_name]
-            print(f"Pre-downloading single file: {filename} from repo: {model_name}")
-            local_path = hf_hub_download(repo_id=model_name, filename=filename)
-            pipe = StableDiffusionXLPipeline.from_single_file(
-                local_path,
-                torch_dtype=torch.float16,
-                use_safetensors=True
-            )
-        else:
-            print(f"Pre-downloading diffusers model: {model_name}")
-            pipe = StableDiffusionXLPipeline.from_pretrained(
-                model_name,
-                torch_dtype=torch.float16,
-                use_safetensors=True
-            )
-
-        status_log.append(f"✅ Successfully downloaded {model_name}")
-        del pipe
-    except Exception as e:
-        status_log.append(f"❌ Failed to download {model_name}: {e}")
-    finally:
-        gc.collect()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
-    return "\n".join(status_log)
-
 def pre_download_loras(civitai_api_key, *lora_data, progress=gr.Progress(track_tqdm=True)):
     civitai_ids = lora_data[0::2]
     status_log = []
@@ -416,20 +407,21 @@ def send_info_to_txt2img(image):
    updates.append(gr.Tabs(selected=0))
    return updates
 
+# --- Execute model download on startup ---
+download_all_base_models_on_startup()
+
+
 with gr.Blocks(css="#col-container {margin: 0 auto; max-width: 1024px;}") as demo:
     gr.Markdown("# Animated SDXL T2I with LoRAs")
     with gr.Tabs(elem_id="tabs_container") as tabs:
         with gr.TabItem("txt2img", id=0):
-            gr.Markdown("<div style='background-color: #282828; color: #a0aec0; padding: 10px; border-radius: 5px; margin-bottom: 15px;'>💡 <b>Tip:</b> Pre-downloading the base model and LoRAs before clicking 'Run' can maximize your ZeroGPU time.</div>")
+            gr.Markdown("<div style='background-color: #282828; color: #a0aec0; padding: 10px; border-radius: 5px; margin-bottom: 15px;'>💡 <b>Tip:</b> Pre-downloading LoRAs before 'Run' can maximize ZeroGPU time.</div>")
             with gr.Column(elem_id="col-container"):
                 with gr.Row():
                     with gr.Column(scale=3):
                         base_model_name = gr.Dropdown(label="Base Model", choices=MODEL_LIST, value="Laxhar/noobai-XL-Vpred-1.0")
-                    with gr.Column(scale=2):
-                        with gr.Row():
-                            predownload_base_model_button = gr.Button("Pre-download Base Model")
-                            predownload_lora_button = gr.Button("Pre-download LoRAs")
-                    with gr.Column(scale=1, min_width=100):
+                    with gr.Column(scale=1):
+                        predownload_lora_button = gr.Button("Pre-download LoRAs")
                         run_button = gr.Button("Run", variant="primary")
 
             predownload_status = gr.Markdown("")
@@ -483,7 +475,7 @@ with gr.Blocks(css="#col-container {margin: 0 auto; max-width: 1024px;}") as demo:
             info_image_input = gr.Image(type="pil", label="Upload Image")
             with gr.Row():
                 info_get_button = gr.Button("Get Info", variant="secondary")
-                send_to_txt2img_button = gr.Button("Send to Txt-to-Image", variant="primary")
+                send_to_txt2img_button = gr.Button("Send to txt2img", variant="primary")
             gr.Markdown("### Positive Prompt"); info_prompt_output = gr.Textbox(lines=3, interactive=False, show_label=False)
             gr.Markdown("### Negative Prompt"); info_neg_prompt_output = gr.Textbox(lines=3, interactive=False, show_label=False)
             gr.Markdown("### Other Parameters"); info_params_output = gr.Textbox(lines=5, interactive=False, show_label=False)
@@ -500,7 +492,6 @@ with gr.Blocks(css="#col-container {margin: 0 auto; max-width: 1024px;}") as demo:
 
     add_lora_button.click(fn=add_lora_row, inputs=[lora_count_state], outputs=[lora_count_state, add_lora_button] + lora_rows)
 
-    predownload_base_model_button.click(fn=pre_download_base_model, inputs=[base_model_name], outputs=[predownload_status])
    predownload_lora_button.click(fn=pre_download_loras, inputs=[civitai_api_key, *all_lora_inputs], outputs=[predownload_status])
 
    run_button.click(fn=infer,