danhtran2mind committed
Commit 7946a9d · verified · Parent: f94fdf4

Upload 84 files

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +21 -0
  2. .python-version +1 -0
  3. CODE_OF_CONDUCT.md +1 -0
  4. CONTRIBUTING.md +1 -0
  5. LICENSE +21 -0
  6. SECURITY.md +1 -0
  7. SUPPORT.md +1 -0
  8. apps/gradio_app.py +33 -0
  9. apps/gradio_app/__init__.py +0 -0
  10. apps/gradio_app/aa.py +603 -0
  11. apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/1/config.json +11 -0
  12. apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/1/result.png +3 -0
  13. apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/2/config.json +11 -0
  14. apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/2/result.png +3 -0
  15. apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/3/config.json +11 -0
  16. apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/3/result.png +3 -0
  17. apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/4/config.json +11 -0
  18. apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/4/result.png +3 -0
  19. apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/1/config.json +13 -0
  20. apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/1/result.png +3 -0
  21. apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/2/config.json +13 -0
  22. apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/2/result.png +3 -0
  23. apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/3/config.json +13 -0
  24. apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/3/result.png +3 -0
  25. apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/4/config.json +13 -0
  26. apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/4/result.png +3 -0
  27. apps/gradio_app/assets/examples/default_image.png +3 -0
  28. apps/gradio_app/config_loader.py +5 -0
  29. apps/gradio_app/example_handler.py +60 -0
  30. apps/gradio_app/gui_components.py +120 -0
  31. apps/gradio_app/image_generator.py +54 -0
  32. apps/gradio_app/old-image_generator.py +77 -0
  33. apps/gradio_app/project_info.py +36 -0
  34. apps/gradio_app/setup_scripts.py +64 -0
  35. apps/gradio_app/static/styles.css +213 -0
  36. apps/old-gradio_app.py +261 -0
  37. apps/old2-gradio_app.py +376 -0
  38. apps/old3-gradio_app.py +438 -0
  39. apps/old4-gradio_app.py +548 -0
  40. apps/old5-gradio_app.py +258 -0
  41. assets/.gitkeep +1 -0
  42. assets/demo_image.png +3 -0
  43. assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/1/config.json +11 -0
  44. assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/1/result.png +3 -0
  45. assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/2/config.json +11 -0
  46. assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/2/result.png +3 -0
  47. assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/3/config.json +11 -0
  48. assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/3/result.png +3 -0
  49. assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/4/config.json +11 -0
  50. assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/4/result.png +3 -0
.gitattributes CHANGED
@@ -33,3 +33,24 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ apps/gradio_app/assets/examples/default_image.png filter=lfs diff=lfs merge=lfs -text
+ apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/1/result.png filter=lfs diff=lfs merge=lfs -text
+ apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/2/result.png filter=lfs diff=lfs merge=lfs -text
+ apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/3/result.png filter=lfs diff=lfs merge=lfs -text
+ apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/4/result.png filter=lfs diff=lfs merge=lfs -text
+ apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/1/result.png filter=lfs diff=lfs merge=lfs -text
+ apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/2/result.png filter=lfs diff=lfs merge=lfs -text
+ apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/3/result.png filter=lfs diff=lfs merge=lfs -text
+ apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/4/result.png filter=lfs diff=lfs merge=lfs -text
+ assets/demo_image.png filter=lfs diff=lfs merge=lfs -text
+ assets/examples/default_image.png filter=lfs diff=lfs merge=lfs -text
+ assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/1/result.png filter=lfs diff=lfs merge=lfs -text
+ assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/2/result.png filter=lfs diff=lfs merge=lfs -text
+ assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/3/result.png filter=lfs diff=lfs merge=lfs -text
+ assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/4/result.png filter=lfs diff=lfs merge=lfs -text
+ assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/1/result.png filter=lfs diff=lfs merge=lfs -text
+ assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/2/result.png filter=lfs diff=lfs merge=lfs -text
+ assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/3/result.png filter=lfs diff=lfs merge=lfs -text
+ assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/4/result.png filter=lfs diff=lfs merge=lfs -text
+ tests/test_data/ghibli_style_output_full_finetuning.png filter=lfs diff=lfs merge=lfs -text
+ tests/test_data/ghibli_style_output_lora.png filter=lfs diff=lfs merge=lfs -text
.python-version ADDED
@@ -0,0 +1 @@
+ 3.10.12
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1 @@
+
CONTRIBUTING.md ADDED
@@ -0,0 +1 @@
+
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Danh Tran
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
SECURITY.md ADDED
@@ -0,0 +1 @@
+
SUPPORT.md ADDED
@@ -0,0 +1 @@
+
apps/gradio_app.py ADDED
@@ -0,0 +1,33 @@
+ import argparse
+ import subprocess
+ import os
+ import torch
+ from gradio_app.gui_components import create_gui
+ from gradio_app.config_loader import load_model_configs
+
+ def run_setup_script():
+     setup_script = os.path.join(os.path.dirname(__file__),
+                                 "gradio_app", "setup_scripts.py")
+     try:
+         result = subprocess.run(["python", setup_script], capture_output=True, text=True, check=True)
+         return result.stdout
+     except subprocess.CalledProcessError as e:
+         print(f"Setup script failed with error: {e.stderr}")
+         return f"Setup script failed: {e.stderr}"
+
+ def main():
+     parser = argparse.ArgumentParser(description="Ghibli Stable Diffusion Synthesis")
+     parser.add_argument("--config_path", type=str, default="configs/model_ckpts.yaml")
+     parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
+     parser.add_argument("--port", type=int, default=7860)
+     parser.add_argument("--share", action="store_true")
+     args = parser.parse_args()
+     print("Running setup script...")
+     run_setup_script()
+     print("Starting Gradio app...")
+     model_configs = load_model_configs(args.config_path)
+     demo = create_gui(model_configs, args.device)
+     demo.launch(server_port=args.port, share=args.share)
+
+ if __name__ == "__main__":
+     main()
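
For reference, this entry point is driven entirely by the CLI flags defined in `main()`. A minimal sketch of invoking it (assuming the working directory is the repository root, so `apps/gradio_app.py` and the default `configs/model_ckpts.yaml` resolve):

```python
# Minimal launch sketch using only the flags declared in main().
# Assumes the current working directory is the repository root.
import subprocess
import sys

subprocess.run(
    [
        sys.executable, "apps/gradio_app.py",
        "--config_path", "configs/model_ckpts.yaml",
        "--port", "7860",
        # "--share",  # uncomment for a public Gradio link (e.g. on Hugging Face Spaces)
    ],
    check=True,
)
```

Note that `run_setup_script()` executes `gradio_app/setup_scripts.py` before the GUI starts, so the first launch also runs whatever setup that script performs before the interface appears.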
apps/gradio_app/__init__.py ADDED
File without changes
apps/gradio_app/aa.py ADDED
@@ -0,0 +1,603 @@
1
+ import argparse
2
+ import json
3
+ from typing import Union, List
4
+ from pathlib import Path
5
+ import os
6
+ import gradio as gr
7
+ import torch
8
+ from PIL import Image
9
+ import numpy as np
10
+ from transformers import CLIPTextModel, CLIPTokenizer
11
+ from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, StableDiffusionPipeline
12
+ from tqdm import tqdm
13
+ import yaml
14
+
15
+ def load_model_configs(config_path: str = "configs/model_ckpts.yaml") -> dict:
16
+ """
17
+ Load model configurations from a YAML file.
18
+ Returns a dictionary with model IDs and their details.
19
+ """
20
+ try:
21
+ with open(config_path, 'r') as f:
22
+ configs = yaml.safe_load(f)
23
+ return {cfg['model_id']: cfg for cfg in configs}
24
+ except (IOError, yaml.YAMLError) as e:
25
+ raise ValueError(f"Error loading {config_path}: {e}")
26
+
27
+ def get_examples(examples_dir: Union[str, List[str]] = None,
28
+ use_lora: Union[bool, None] = None) -> List:
29
+ # Convert single string to list
30
+ directories = [examples_dir] if isinstance(examples_dir, str) else examples_dir or []
31
+
32
+ # Validate directories
33
+ valid_dirs = [d for d in directories if os.path.isdir(d)]
34
+ if not valid_dirs:
35
+ print("Error: No valid directories found, using provided examples")
36
+ return get_provided_examples(use_lora)
37
+
38
+ examples = []
39
+ for dir_path in valid_dirs:
40
+ # Get sorted subdirectories
41
+ subdirs = sorted(
42
+ os.path.join(dir_path, d) for d in os.listdir(dir_path) if os.path.isdir(os.path.join(dir_path, d))
43
+ )
44
+
45
+ for subdir in subdirs:
46
+ config_path = os.path.join(subdir, "config.json")
47
+ image_path = os.path.join(subdir, "result.png")
48
+
49
+ if not (os.path.isfile(config_path) and os.path.isfile(image_path)):
50
+ print(f"Error: Missing config.json or result.png in {subdir}")
51
+ continue
52
+
53
+ try:
54
+ with open(config_path, 'r') as f:
55
+ config = json.load(f)
56
+ except (json.JSONDecodeError, IOError) as e:
57
+ print(f"Error reading {config_path}: {e}")
58
+ continue
59
+
60
+ required_keys = ["prompt", "height", "width", "num_inference_steps", "guidance_scale", "seed", "image"]
61
+ if config.get("use_lora", False):
62
+ required_keys.extend(["lora_model_id", "base_model_id", "lora_rank", "lora_scale"])
63
+ else:
64
+ required_keys.append("finetune_model_id")
65
+
66
+ if missing_keys := set(required_keys) - set(config.keys()):
67
+ print(f"Error: Missing keys in {config_path}: {', '.join(missing_keys)}")
68
+ continue
69
+
70
+ if config["image"] != "result.png":
71
+ print(f"Error: Image key in {config_path} does not match 'result.png'")
72
+ continue
73
+
74
+ try:
75
+ Image.open(image_path).verify()
76
+ image = Image.open(image_path) # Re-open after verify
77
+ except Exception as e:
78
+ print(f"Error: Invalid image {image_path}: {e}")
79
+ continue
80
+
81
+ if use_lora is not None and config.get("use_lora", False) != use_lora:
82
+ print(f"DEBUG: Skipping {config_path} due to use_lora mismatch (expected {use_lora}, got {config.get('use_lora', False)})")
83
+ continue
84
+
85
+ # Build example list based on use_lora
86
+ example = [
87
+ config["prompt"],
88
+ config["height"],
89
+ config["width"],
90
+ config["num_inference_steps"],
91
+ config["guidance_scale"],
92
+ config["seed"],
93
+ image,
94
+ # config.get("use_lora", False)
95
+ ]
96
+ if config.get("use_lora", False):
97
+ example.extend([
98
+ config["lora_model_id"],
99
+ config["base_model_id"],
100
+ config["lora_rank"],
101
+ config["lora_scale"]
102
+ ])
103
+ else:
104
+ example.append(config["finetune_model_id"])
105
+
106
+ examples.append(example)
107
+ print(f"DEBUG: Loaded example from {config_path}: {example[:6]}")
108
+
109
+ return examples or get_provided_examples(use_lora)
110
+
111
+ def get_provided_examples(use_lora: bool = False) -> list:
112
+ example1_image = None
113
+ example2_image = None
114
+ # Attempt to load example images
115
+ if use_lora:
116
+ try:
117
+ example2_path = "apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/1/result.png"
118
+ if os.path.exists(example2_path):
119
+ example2_image = Image.open(example2_path)
120
+ except Exception as e:
121
+ print(f"Failed to load example2 image: {e}")
122
+ output = [list({
123
+ "prompt": "a cat is laying on a sofa in Ghibli style",
124
+ "width": 512,
125
+ "height": 768,
126
+ "steps": 100,
127
+ "cfg_scale": 10.0,
128
+ "seed": 789,
129
+ "image": example2_path, # example2_image,
130
+ # "use_lora": True,
131
+ "model": "danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA",
132
+ "base_model": "stabilityai/stable-diffusion-2-1",
133
+ "lora_rank": 64,
134
+ "lora_alpha": 0.9
135
+ }.values())]
136
+
137
+ else:
138
+ try:
139
+ example1_path = "apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/1/result.png"
140
+ if os.path.exists(example1_path):
141
+ example1_image = Image.open(example1_path)
142
+ except Exception as e:
143
+ print(f"Failed to load example1 image: {e}")
144
+ output = [list({
145
+ "prompt": "a serene landscape in Ghibli style",
146
+ "width": 256,
147
+ "height": 512,
148
+ "steps": 50,
149
+ "cfg_scale": 3.5,
150
+ "seed": 42,
151
+ "image": example1_path, # example1_image,
152
+ # "use_lora": False,
153
+ "model": "danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning"
154
+ }.values())]
155
+
156
+ return output
157
+
158
+ def create_demo(
159
+ config_path: str = "configs/model_ckpts.yaml",
160
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
161
+ ):
162
+ model_configs = load_model_configs(config_path)
163
+
164
+ finetune_model_id = next((mid for mid, cfg in model_configs.items() if cfg.get('type') == 'full_finetuning'), None)
165
+ lora_model_id = next((mid for mid, cfg in model_configs.items() if cfg.get('type') == 'lora'), None)
166
+
167
+ if not finetune_model_id or not lora_model_id:
168
+ raise ValueError("Could not find full_finetuning or lora model IDs in the configuration file.")
169
+
170
+ finetune_config = model_configs.get(finetune_model_id, {})
171
+ finetune_local_dir = finetune_config.get('local_dir')
172
+ if finetune_local_dir and os.path.exists(finetune_local_dir) and any(os.path.isfile(os.path.join(finetune_local_dir, f)) for f in os.listdir(finetune_local_dir)):
173
+ finetune_model_path = finetune_local_dir
174
+ else:
175
+ finetune_model_path = finetune_model_id
176
+
177
+ lora_config = model_configs.get(lora_model_id, {})
178
+ lora_local_dir = lora_config.get('local_dir')
179
+ if lora_local_dir and os.path.exists(lora_local_dir) and any(os.path.isfile(os.path.join(lora_local_dir, f)) for f in os.listdir(lora_local_dir)):
180
+ lora_model_path = lora_local_dir
181
+ else:
182
+ lora_model_path = lora_model_id
183
+
184
+ base_model_id = lora_config.get('base_model_id', 'stabilityai/stable-diffusion-2-1')
185
+ base_model_config = model_configs.get(base_model_id, {})
186
+ base_local_dir = base_model_config.get('local_dir')
187
+ if base_local_dir and os.path.exists(base_local_dir) and any(os.path.isfile(os.path.join(base_local_dir, f)) for f in os.listdir(base_local_dir)):
188
+ base_model_path = base_local_dir
189
+ else:
190
+ base_model_path = base_model_id
191
+
192
+ device = torch.device(device)
193
+ dtype = torch.float16 if device.type == "cuda" else torch.float32
194
+
195
+ finetune_model_ids = [mid for mid, cfg in model_configs.items() if cfg.get('type') == 'full_finetuning']
196
+ lora_model_ids = [mid for mid, cfg in model_configs.items() if cfg.get('type') == 'lora']
197
+ base_model_ids = [model_configs[mid].get('base_model_id') for mid in model_configs if model_configs[mid].get('base_model_id')]
198
+
199
+ def generate_image(prompt, height, width, num_inference_steps, guidance_scale, seed, random_seed, use_lora, finetune_model_id, lora_model_id, base_model_id, lora_rank, lora_scale):
200
+ try:
201
+ model_configs = load_model_configs(config_path)
202
+ finetune_config = model_configs.get(finetune_model_id, {})
203
+ finetune_local_dir = finetune_config.get('local_dir')
204
+ finetune_model_path = finetune_local_dir if finetune_local_dir and os.path.exists(finetune_local_dir) and any(os.path.isfile(os.path.join(finetune_local_dir, f)) for f in os.listdir(finetune_local_dir)) else finetune_model_id
205
+
206
+ lora_config = model_configs.get(lora_model_id, {})
207
+ lora_local_dir = lora_config.get('local_dir')
208
+ lora_model_path = lora_local_dir if lora_local_dir and os.path.exists(lora_local_dir) and any(os.path.isfile(os.path.join(lora_local_dir, f)) for f in os.listdir(lora_local_dir)) else lora_model_id
209
+
210
+ base_model_config = model_configs.get(base_model_id, {})
211
+ base_local_dir = base_model_config.get('local_dir')
212
+ base_model_path = base_local_dir if base_local_dir and os.path.exists(base_local_dir) and any(os.path.isfile(os.path.join(base_local_dir, f)) for f in os.listdir(base_local_dir)) else base_model_id
213
+
214
+ if not prompt:
215
+ return None, "Prompt cannot be empty."
216
+ if height % 8 != 0 or width % 8 != 0:
217
+ return None, "Height and width must be divisible by 8."
218
+ if num_inference_steps < 1 or num_inference_steps > 100:
219
+ return None, "Number of inference steps must be between 1 and 100."
220
+ if guidance_scale < 1.0 or guidance_scale > 20.0:
221
+ return None, "Guidance scale must be between 1.0 and 20.0."
222
+ if seed < 0 or seed > 4294967295:
223
+ return None, "Seed must be between 0 and 4294967295."
224
+ if use_lora and (not lora_model_path or not os.path.exists(lora_model_path) and not lora_model_path.startswith("danhtran2mind/")):
225
+ return None, f"LoRA model path {lora_model_path} does not exist or is invalid."
226
+ if use_lora and (not base_model_path or not os.path.exists(base_model_path) and not base_model_path.startswith("stabilityai/")):
227
+ return None, f"Base model path {base_model_path} does not exist or is invalid."
228
+ if not use_lora and (not finetune_model_path or not os.path.exists(finetune_model_path) and not finetune_model_path.startswith("danhtran2mind/")):
229
+ return None, f"Fine-tuned model path {finetune_model_path} does not exist or is invalid."
230
+ if use_lora and (lora_rank < 1 or lora_rank > 128):
231
+ return None, "LoRA rank must be between 1 and 128."
232
+ if use_lora and (lora_scale < 0.0 or lora_scale > 2.0):
233
+ return None, "LoRA scale must be between 0.0 and 2.0."
234
+
235
+ batch_size = 1
236
+ if random_seed:
237
+ seed = torch.randint(0, 4294967295, (1,)).item()
238
+ generator = torch.Generator(device=device).manual_seed(int(seed))
239
+
240
+ if use_lora:
241
+ try:
242
+ pipe = StableDiffusionPipeline.from_pretrained(
243
+ base_model_path, torch_dtype=dtype, use_safetensors=True
244
+ )
245
+ pipe.load_lora_weights(lora_model_path, adapter_name="ghibli-lora", lora_scale=lora_scale)
246
+ pipe = pipe.to(device)
247
+ vae = pipe.vae
248
+ tokenizer = pipe.tokenizer
249
+ text_encoder = pipe.text_encoder
250
+ unet = pipe.unet
251
+ scheduler = PNDMScheduler.from_config(pipe.scheduler.config)
252
+ except Exception as e:
253
+ return None, f"Error loading LoRA model: {e}"
254
+ else:
255
+ try:
256
+ vae = AutoencoderKL.from_pretrained(finetune_model_path, subfolder="vae", torch_dtype=dtype).to(device)
257
+ tokenizer = CLIPTokenizer.from_pretrained(finetune_model_path, subfolder="tokenizer")
258
+ text_encoder = CLIPTextModel.from_pretrained(finetune_model_path, subfolder="text_encoder", torch_dtype=dtype).to(device)
259
+ unet = UNet2DConditionModel.from_pretrained(finetune_model_path, subfolder="unet", torch_dtype=dtype).to(device)
260
+ scheduler = PNDMScheduler.from_pretrained(finetune_model_path, subfolder="scheduler")
261
+ except Exception as e:
262
+ return None, f"Error loading fine-tuned model: {e}"
263
+
264
+ text_input = tokenizer(
265
+ [prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
266
+ )
267
+ with torch.no_grad():
268
+ text_embeddings = text_encoder(text_input.input_ids.to(device))[0].to(dtype=dtype)
269
+
270
+ max_length = text_input.input_ids.shape[-1]
271
+ uncond_input = tokenizer(
272
+ [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
273
+ )
274
+ with torch.no_grad():
275
+ uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0].to(dtype=dtype)
276
+
277
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
278
+
279
+ latents = torch.randn(
280
+ (batch_size, unet.config.in_channels, height // 8, width // 8),
281
+ generator=generator, dtype=dtype, device=device
282
+ )
283
+
284
+ scheduler.set_timesteps(num_inference_steps)
285
+ latents = latents * scheduler.init_noise_sigma
286
+
287
+ for t in tqdm(scheduler.timesteps, desc="Generating image"):
288
+ latent_model_input = torch.cat([latents] * 2)
289
+ latent_model_input = scheduler.scale_model_input(latent_model_input, t)
290
+
291
+ with torch.no_grad():
292
+ if device.type == "cuda":
293
+ with torch.autocast(device_type="cuda", dtype=torch.float16):
294
+ noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
295
+ else:
296
+ noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
297
+
298
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
299
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
300
+ latents = scheduler.step(noise_pred, t, latents).prev_sample
301
+
302
+ with torch.no_grad():
303
+ latents = latents / vae.config.scaling_factor
304
+ image = vae.decode(latents).sample
305
+
306
+ image = (image / 2 + 0.5).clamp(0, 1)
307
+ image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
308
+ image = (image * 255).round().astype("uint8")
309
+ pil_image = Image.fromarray(image[0])
310
+
311
+ if use_lora:
312
+ del pipe
313
+ else:
314
+ del vae, tokenizer, text_encoder, unet, scheduler
315
+ torch.cuda.empty_cache()
316
+
317
+ return pil_image, f"Generated image successfully! Seed used: {seed}"
318
+ except Exception as e:
319
+ return None, f"Failed to generate image: {e}"
320
+
321
+ def load_example_image_full_finetuning(prompt, height, width, num_inference_steps, guidance_scale,
322
+ seed, image, finetune_model_id):
323
+ try:
324
+ status = "Loaded example successfully"
325
+ return (
326
+ prompt, height, width, num_inference_steps, guidance_scale, seed,
327
+ image, finetune_model_id, status
328
+ )
329
+ except Exception as e:
330
+ print(f"DEBUG: Exception in load_example_image: {e}")
331
+ return (
332
+ prompt, height, width, num_inference_steps, guidance_scale, seed,
333
+ None, finetune_model_id,
334
+ f"Error loading example: {e}"
335
+ )
336
+
337
+ def load_example_image_lora(prompt, height, width, num_inference_steps, guidance_scale,
338
+ seed, image, lora_model_id,
339
+ base_model_id, lora_rank, lora_scale):
340
+ try:
341
+ status = "Loaded example successfully"
342
+ # Ensure base_model_id, lora_rank, and lora_scale have valid values
343
+ base_model_id = base_model_id or "stabilityai/stable-diffusion-2-1"
344
+ lora_rank = lora_rank if lora_rank is not None else 64
345
+ lora_scale = lora_scale if lora_scale is not None else 1.2
346
+
347
+ return (
348
+ prompt, height, width, num_inference_steps, guidance_scale, seed,
349
+ image, lora_model_id, base_model_id,
350
+ lora_rank, lora_scale, status
351
+ )
352
+ except Exception as e:
353
+ print(f"DEBUG: Exception in load_example_image_lora: {e}")
354
+ return (
355
+ prompt, height, width, num_inference_steps, guidance_scale, seed,
356
+ None, lora_model_id, base_model_id or "stabilityai/stable-diffusion-2-1",
357
+ lora_rank or 64, lora_scale or 1.2, f"Error loading example: {e}"
358
+ )
359
+
360
+ badges_text = r"""
361
+ <div style="text-align: left; font-size: 14px; display: flex; flex-direction: column; gap: 10px;">
362
+ <div style="display: flex; align-items: center; justify-content: left; gap: 8px;">
363
+ You can explore GitHub repository:
364
+ <a href="https://github.com/danhtran2mind/Ghibli-Stable-Diffusion-Synthesis">
365
+ <img src="https://img.shields.io/badge/GitHub-danhtran2mind%2FGhibli--Stable--Diffusion--Synthesis-blue?style=flat&logo=github" alt="GitHub Repo">
366
+ </a>. And you can explore HuggingFace Model Hub:
367
+ <a href="https://huggingface.co/spaces/danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning">
368
+ <img src="https://img.shields.io/badge/HuggingFace-danhtran2mind%2FGhibli--Stable--Diffusion--2.1--Base--finetuning-yellow?style=flat&logo=huggingface" alt="HuggingFace Space Demo">
369
+ </a>
370
+ and
371
+ <a href="https://huggingface.co/spaces/danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA">
372
+ <img src="https://img.shields.io/badge/HuggingFace-danhtran2mind%2FGhibli--Stable--Diffusion--2.1--LoRA-yellow?style=flat&logo=huggingface" alt="HuggingFace Space Demo">
373
+ </a>
374
+ </div>
375
+ </div>
376
+ """.strip()
377
+
378
+ try:
379
+ custom_css = open("apps/gradio_app/static/styles.css", "r").read()
380
+ except FileNotFoundError:
381
+ print("Error: styles.css not found, using default styling")
382
+ custom_css = ""
383
+
384
+ examples_full_finetuning = get_examples("apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning",
385
+ use_lora=False)
386
+ examples_lora = get_examples("apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA",
387
+ use_lora=True)
388
+
389
+ with gr.Blocks(css=custom_css, theme="ocean") as demo:
390
+ gr.Markdown("## Ghibli-Style Image Generator")
391
+ with gr.Tabs():
392
+ with gr.Tab(label="Full Finetuning"):
393
+ with gr.Row():
394
+ with gr.Column(scale=1):
395
+ gr.Markdown("### Image Generation Settings")
396
+ prompt_ft = gr.Textbox(
397
+ label="Prompt",
398
+ placeholder="e.g., 'a serene landscape in Ghibli style'",
399
+ lines=2
400
+ )
401
+ with gr.Group():
402
+ gr.Markdown("#### Image Dimensions")
403
+ with gr.Row():
404
+ width_ft = gr.Slider(
405
+ minimum=32, maximum=4096, value=512, step=8, label="Width"
406
+ )
407
+ height_ft = gr.Slider(
408
+ minimum=32, maximum=4096, value=512, step=8, label="Height"
409
+ )
410
+ with gr.Accordion("Advanced Settings", open=False):
411
+ num_inference_steps_ft = gr.Slider(
412
+ minimum=1, maximum=100, value=50, step=1, label="Inference Steps",
413
+ info="More steps, better quality, longer wait."
414
+ )
415
+ guidance_scale_ft = gr.Slider(
416
+ minimum=1.0, maximum=20.0, value=3.5, step=0.5, label="Guidance Scale",
417
+ info="Controls how closely the image follows the prompt."
418
+ )
419
+ random_seed_ft = gr.Checkbox(label="Use Random Seed", value=False)
420
+ seed_ft = gr.Slider(
421
+ minimum=0, maximum=4294967295, value=42, step=1,
422
+ label="Seed", info="Use a seed (0-4294967295) for consistent results."
423
+ )
424
+ with gr.Group():
425
+ gr.Markdown("#### Model Configuration")
426
+ finetune_model_path_ft = gr.Dropdown(
427
+ label="Fine-tuned Model", choices=finetune_model_ids,
428
+ value=finetune_model_id
429
+ )
430
+ # image_path_ft = gr.Textbox(visible=False)
431
+
432
+ with gr.Column(scale=1):
433
+ gr.Markdown("### Generated Result")
434
+ output_image_ft = gr.Image(label="Generated Image", interactive=False, height=512)
435
+ output_text_ft = gr.Textbox(label="Status", interactive=False, lines=3)
436
+
437
+ generate_btn_ft = gr.Button("Generate Image", variant="primary")
438
+ stop_btn_ft = gr.Button("Stop Generation")
439
+
440
+ gr.Markdown("### Examples for Full Finetuning")
441
+ gr.Examples(
442
+ examples=examples_full_finetuning,
443
+ inputs=[
444
+ prompt_ft, height_ft, width_ft, num_inference_steps_ft,
445
+ guidance_scale_ft, seed_ft, output_image_ft, finetune_model_path_ft
446
+ ],
447
+ outputs=[prompt_ft, height_ft, width_ft, num_inference_steps_ft,
448
+ guidance_scale_ft, seed_ft, output_image_ft, finetune_model_path_ft,
449
+ output_text_ft],
450
+ fn=load_example_image_full_finetuning,
451
+ # fn=lambda *args: load_example_image_full_finetuning(*args),
452
+ cache_examples=False,
453
+ label="Examples for Full Fine-tuning",
454
+ examples_per_page=4
455
+ )
456
+
457
+ with gr.Tab(label="LoRA"):
458
+ with gr.Row():
459
+ with gr.Column(scale=1):
460
+ gr.Markdown("### Image Generation Settings")
461
+ prompt_lora = gr.Textbox(
462
+ label="Prompt",
463
+ placeholder="e.g., 'a serene landscape in Ghibli style'",
464
+ lines=2
465
+ )
466
+ with gr.Group():
467
+ gr.Markdown("#### Image Dimensions")
468
+ with gr.Row():
469
+ width_lora = gr.Slider(
470
+ minimum=32, maximum=4096, value=512, step=8, label="Width"
471
+ )
472
+ height_lora = gr.Slider(
473
+ minimum=32, maximum=4096, value=512, step=8, label="Height"
474
+ )
475
+ with gr.Accordion("Advanced Settings", open=False):
476
+ num_inference_steps_lora = gr.Slider(
477
+ minimum=1, maximum=100, value=50, step=1, label="Inference Steps",
478
+ info="More steps, better quality, longer wait."
479
+ )
480
+ guidance_scale_lora = gr.Slider(
481
+ minimum=1.0, maximum=20.0, value=3.5, step=0.5, label="Guidance Scale",
482
+ info="Controls how closely the image follows the prompt."
483
+ )
484
+ lora_rank_lora = gr.Slider(
485
+ minimum=1, maximum=128, value=64, step=1, label="LoRA Rank",
486
+ info="Controls model complexity and memory usage."
487
+ )
488
+ lora_scale_lora = gr.Slider(
489
+ minimum=0.0, maximum=2.0, value=1.2, step=0.1, label="LoRA Scale",
490
+ info="Adjusts the influence of LoRA weights."
491
+ )
492
+ random_seed_lora = gr.Checkbox(label="Use Random Seed", value=False)
493
+ seed_lora = gr.Slider(
494
+ minimum=0, maximum=4294967295, value=42, step=1,
495
+ label="Seed", info="Use a seed (0-4294967295) for consistent results."
496
+ )
497
+ with gr.Group():
498
+ gr.Markdown("#### Model Configuration")
499
+ lora_model_path_lora = gr.Dropdown(
500
+ label="LoRA Model", choices=lora_model_ids,
501
+ value=lora_model_id
502
+ )
503
+ base_model_path_lora = gr.Dropdown(
504
+ label="Base Model", choices=base_model_ids,
505
+ value=base_model_id
506
+ )
507
+ # image_path_lora = gr.Textbox(visible=False)
508
+
509
+ with gr.Column(scale=1):
510
+ gr.Markdown("### Generated Result")
511
+ output_image_lora = gr.Image(label="Generated Image", interactive=False, height=512)
512
+ output_text_lora = gr.Textbox(label="Status", interactive=False, lines=3)
513
+
514
+ generate_btn_lora = gr.Button("Generate Image", variant="primary")
515
+ stop_btn_lora = gr.Button("Stop Generation")
516
+
517
+ gr.Markdown("### Examples for LoRA")
518
+ gr.Examples(
519
+ examples=examples_lora,
520
+ inputs=[
521
+ prompt_lora, height_lora, width_lora, num_inference_steps_lora,
522
+ guidance_scale_lora, seed_lora, output_image_lora,
523
+ lora_model_path_lora, base_model_path_lora,
524
+ lora_rank_lora, lora_scale_lora
525
+ ],
526
+ outputs=[
527
+ prompt_lora, height_lora, width_lora, num_inference_steps_lora,
528
+ guidance_scale_lora, seed_lora, output_image_lora,
529
+ lora_model_path_lora, base_model_path_lora,
530
+ lora_rank_lora, lora_scale_lora,
531
+ output_text_lora
532
+ ],
533
+ fn=load_example_image_lora,
534
+ # fn=lambda *args: load_example_image_lora(*args),
535
+ cache_examples=False,
536
+ label="Examples for LoRA",
537
+ examples_per_page=4
538
+ )
539
+
540
+ gr.Markdown(badges_text)
541
+
542
+ generate_event_ft = generate_btn_ft.click(
543
+ fn=generate_image,
544
+ inputs=[
545
+ prompt_ft, height_ft, width_ft, num_inference_steps_ft, guidance_scale_ft, seed_ft,
546
+ random_seed_ft, gr.State(value=False), finetune_model_path_ft, gr.State(value=None),
547
+ gr.State(value=None), gr.State(value=None), gr.State(value=None)
548
+ ],
549
+ outputs=[output_image_ft, output_text_ft]
550
+ )
551
+
552
+ generate_event_lora = generate_btn_lora.click(
553
+ fn=generate_image,
554
+ inputs=[
555
+ prompt_lora, height_lora, width_lora, num_inference_steps_lora, guidance_scale_lora, seed_lora,
556
+ random_seed_lora, gr.State(value=True), gr.State(value=None), lora_model_path_lora,
557
+ base_model_path_lora, lora_rank_lora, lora_scale_lora
558
+ ],
559
+ outputs=[output_image_lora, output_text_lora]
560
+ )
561
+
562
+ stop_btn_ft.click(fn=None, inputs=None, outputs=None, cancels=[generate_event_ft])
563
+ stop_btn_lora.click(fn=None, inputs=None, outputs=None, cancels=[generate_event_lora])
564
+
565
+ def cleanup():
566
+ print("DEBUG: Cleaning up resources...")
567
+ torch.cuda.empty_cache()
568
+
569
+ demo.unload(cleanup)
570
+
571
+ return demo
572
+
573
+ if __name__ == "__main__":
574
+ parser = argparse.ArgumentParser(description="Ghibli-Style Image Generator using a fine-tuned Stable Diffusion model or Stable Diffusion 2.1 with LoRA weights.")
575
+ parser.add_argument(
576
+ "--config_path",
577
+ type=str,
578
+ default="configs/model_ckpts.yaml",
579
+ help="Path to the model configuration YAML file."
580
+ )
581
+ parser.add_argument(
582
+ "--device",
583
+ type=str,
584
+ default="cuda" if torch.cuda.is_available() else "cpu",
585
+ help="Device to run the model on (e.g., 'cuda', 'cpu')."
586
+ )
587
+ parser.add_argument(
588
+ "--port",
589
+ type=int,
590
+ default=7860,
591
+ help="Port to run the Gradio app on."
592
+ )
593
+ parser.add_argument(
594
+ "--share",
595
+ action="store_true",
596
+ default=False,
597
+ help="Set to True for public sharing (Hugging Face Spaces)."
598
+ )
599
+
600
+ args = parser.parse_args()
601
+
602
+ demo = create_demo(args.config_path, args.device)
603
+ demo.launch(server_port=args.port, share=args.share)
apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/1/config.json ADDED
@@ -0,0 +1,11 @@
+ {
+ "prompt": "a serene landscape in Ghibli style",
+ "height": 256,
+ "width": 512,
+ "num_inference_steps": 50,
+ "guidance_scale": 3.5,
+ "seed": 42,
+ "image": "result.png",
+ "use_lora": false,
+ "finetune_model_id": "danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning"
+ }
apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/1/result.png ADDED

Git LFS Details

  • SHA256: 8a955ecacd6b904093b65a7328bb1fdfc874f0866766e6f6d09bc73551a80d30
  • Pointer size: 131 Bytes
  • Size of remote file: 198 kB
apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/2/config.json ADDED
@@ -0,0 +1,11 @@
+ {
+ "prompt": "Donald Trump",
+ "height": 512,
+ "width": 512,
+ "num_inference_steps": 100,
+ "guidance_scale": 9,
+ "seed": 200,
+ "image": "result.png",
+ "use_lora": false,
+ "finetune_model_id": "danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning"
+ }
apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/2/result.png ADDED

Git LFS Details

  • SHA256: 3e0d8bab61ede83e5e05171b93f5aa781780ee43c955bb30f95af8554587e9bd
  • Pointer size: 131 Bytes
  • Size of remote file: 232 kB
apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/3/config.json ADDED
@@ -0,0 +1,11 @@
+ {
+ "prompt": "a dancer in Ghibli style",
+ "height": 384,
+ "width": 192,
+ "num_inference_steps": 50,
+ "guidance_scale": 15.5,
+ "seed": 4223,
+ "image": "result.png",
+ "use_lora": false,
+ "finetune_model_id": "danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning"
+ }
apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/3/result.png ADDED

Git LFS Details

  • SHA256: 5ef6e36606a3cfbb73a0a2a2a08b80c70e6405ddebb686d9db6108a3eed4ecb0
  • Pointer size: 131 Bytes
  • Size of remote file: 164 kB
apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/4/config.json ADDED
@@ -0,0 +1,11 @@
+ {
+ "prompt": "Ghibli style, the peace beach",
+ "height": 1024,
+ "width": 2048,
+ "num_inference_steps": 100,
+ "guidance_scale": 7.5,
+ "seed": 5678,
+ "image": "result.png",
+ "use_lora": false,
+ "finetune_model_id": "danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning"
+ }
apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/4/result.png ADDED

Git LFS Details

  • SHA256: 258a57cac793da71ede5b5ecf4d752a747aee3d9022ef61947cc4e82fe8d7f51
  • Pointer size: 132 Bytes
  • Size of remote file: 3.16 MB
apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/1/config.json ADDED
@@ -0,0 +1,13 @@
+ {
+ "prompt": "a cat is laying on a sofa in Ghibli style",
+ "height": 512,
+ "width": 768,
+ "num_inference_steps": 100,
+ "guidance_scale": 10,
+ "seed": 789,
+ "image": "result.png",
+ "use_lora": true,
+ "lora_model_id": "danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA",
+ "base_model_id": "stabilityai/stable-diffusion-2-1",
+ "lora_scale": 0.9
+ }
apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/1/result.png ADDED

Git LFS Details

  • SHA256: 8e6861fa71cdb6b2c7d2d643de12ba6889cf251f0abfa25d21c63eb3ad2b5893
  • Pointer size: 131 Bytes
  • Size of remote file: 411 kB
apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/2/config.json ADDED
@@ -0,0 +1,13 @@
+ {
+ "prompt": "Ghibli style, a majestic mountain towers, casting shadows on the serene beach.",
+ "height": 1024,
+ "width": 2048,
+ "num_inference_steps": 75,
+ "guidance_scale": 14.5,
+ "seed": 9999,
+ "image": "result.png",
+ "use_lora": true,
+ "lora_model_id": "danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA",
+ "base_model_id": "stabilityai/stable-diffusion-2-1",
+ "lora_scale": 1
+ }
apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/2/result.png ADDED

Git LFS Details

  • SHA256: db4a9730beeba9eb6ed88a630f8723082e3e975b091b604512f72f45d8f034a3
  • Pointer size: 132 Bytes
  • Size of remote file: 2.59 MB
apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/3/config.json ADDED
@@ -0,0 +1,13 @@
+ {
+ "prompt": "In a soft, Ghibli style, Elon Musk is in a suit.",
+ "height": 512,
+ "width": 512,
+ "num_inference_steps": 82,
+ "guidance_scale": 18,
+ "seed": 1,
+ "image": "result.png",
+ "use_lora": true,
+ "lora_model_id": "danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA",
+ "base_model_id": "stabilityai/stable-diffusion-2-1",
+ "lora_scale": 1.4
+ }
apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/3/result.png ADDED

Git LFS Details

  • SHA256: 60c62a82123d5f05959f604954ccfebce0bfffb7ee17197b6b0c66fda11ae55c
  • Pointer size: 131 Bytes
  • Size of remote file: 348 kB
apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/4/config.json ADDED
@@ -0,0 +1,13 @@
+ {
+ "prompt": "In a Ghibli-esque world, A close-up shows a race car's soft, sun-drenched, whimsical details.",
+ "height": 1024,
+ "width": 1024,
+ "num_inference_steps": 42,
+ "guidance_scale": 20,
+ "seed": 1589,
+ "image": "result.png",
+ "use_lora": true,
+ "lora_model_id": "danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA",
+ "base_model_id": "stabilityai/stable-diffusion-2-1",
+ "lora_scale": 0.7
+ }
apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/4/result.png ADDED

Git LFS Details

  • SHA256: 79121478e4d89a673c60e0d158d278ec53dafc1062e5fe3d43ce622c8c0bf4da
  • Pointer size: 132 Bytes
  • Size of remote file: 1.14 MB
apps/gradio_app/assets/examples/default_image.png ADDED

Git LFS Details

  • SHA256: 8a955ecacd6b904093b65a7328bb1fdfc874f0866766e6f6d09bc73551a80d30
  • Pointer size: 131 Bytes
  • Size of remote file: 198 kB
apps/gradio_app/config_loader.py ADDED
@@ -0,0 +1,5 @@
+ import yaml
+
+ def load_model_configs(config_path: str = "configs/model_ckpts.yaml") -> dict:
+     with open(config_path, 'r') as f:
+         return {cfg['model_id']: cfg for cfg in yaml.safe_load(f)}
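
`load_model_configs` expects `configs/model_ckpts.yaml` to be a YAML list of entries, each carrying at least a `model_id`; elsewhere in the app the `type`, `local_dir`, and `base_model_id` fields of these entries are read as well. A minimal sketch of that shape (the `local_dir` paths are illustrative assumptions, not taken from this commit):

```python
# Sketch of the list structure load_model_configs() reads from
# configs/model_ckpts.yaml; the local_dir values are assumptions.
import yaml

model_ckpts = [
    {
        "model_id": "danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning",
        "type": "full_finetuning",
        "local_dir": "ckpts/Ghibli-Stable-Diffusion-2.1-Base-finetuning",  # assumed path
    },
    {
        "model_id": "danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA",
        "type": "lora",
        "base_model_id": "stabilityai/stable-diffusion-2-1",
        "local_dir": "ckpts/Ghibli-Stable-Diffusion-2.1-LoRA",  # assumed path
    },
]

with open("configs/model_ckpts.yaml", "w") as f:
    yaml.safe_dump(model_ckpts, f, sort_keys=False)

# load_model_configs() then re-indexes the list by model_id, e.g.
# {"danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA": {...}, ...}
```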
apps/gradio_app/example_handler.py ADDED
@@ -0,0 +1,60 @@
1
+ import os
2
+ import json
3
+ from typing import Union, List
4
+ from PIL import Image
5
+
6
+ def get_examples(examples_dir: Union[str, List[str]] = None, use_lora: bool = None) -> List:
7
+ directories = [examples_dir] if isinstance(examples_dir, str) else examples_dir or []
8
+ valid_dirs = [d for d in directories if os.path.isdir(d)]
9
+ if not valid_dirs:
10
+ return get_provided_examples(use_lora)
11
+
12
+ examples = []
13
+ for dir_path in valid_dirs:
14
+ for subdir in sorted(os.path.join(dir_path, d) for d in os.listdir(dir_path) if os.path.isdir(os.path.join(dir_path, d))):
15
+ config_path = os.path.join(subdir, "config.json")
16
+ image_path = os.path.join(subdir, "result.png")
17
+ if not (os.path.isfile(config_path) and os.path.isfile(image_path)):
18
+ continue
19
+
20
+ with open(config_path, 'r') as f:
21
+ config = json.load(f)
22
+
23
+ required_keys = ["prompt", "height", "width", "num_inference_steps", "guidance_scale", "seed", "image"]
24
+ if config.get("use_lora", False):
25
+ required_keys.extend(["lora_model_id", "base_model_id",
26
+ # "lora_rank",
27
+ "lora_scale"])
28
+ else:
29
+ required_keys.append("finetune_model_id")
30
+
31
+ if set(required_keys) - set(config.keys()) or config["image"] != "result.png":
32
+ continue
33
+
34
+ try:
35
+ image = Image.open(image_path)
36
+ except Exception:
37
+ continue
38
+
39
+ if use_lora is not None and config.get("use_lora", False) != use_lora:
40
+ continue
41
+
42
+ example = [config["prompt"], config["height"], config["width"], config["num_inference_steps"],
43
+ config["guidance_scale"], config["seed"], image]
44
+ example.extend([config["lora_model_id"], config["base_model_id"],
45
+ # config["lora_rank"],
46
+ config["lora_scale"]]
47
+ if config.get("use_lora", False) else [config["finetune_model_id"]])
48
+ examples.append(example)
49
+
50
+ return examples or get_provided_examples(use_lora)
51
+
52
+ def get_provided_examples(use_lora: bool = False) -> list:
53
+ example_path = f"apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-{'LoRA' if use_lora else 'Base-finetuning'}/1/result.png"
54
+ image = Image.open(example_path) if os.path.exists(example_path) else None
55
+ return [[
56
+ "a cat is laying on a sofa in Ghibli style" if use_lora else "a serene landscape in Ghibli style",
57
+ 512, 768 if use_lora else 512, 100 if use_lora else 50, 10.0 if use_lora else 3.5, 789 if use_lora else 42,
58
+ image, "danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA" if use_lora else "danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning",
59
+ "stabilityai/stable-diffusion-2-1" if use_lora else None, 64 if use_lora else None, 0.9 if use_lora else None
60
+ ]]
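
`get_examples` scans one level of numbered sub-directories under the given examples folder and keeps an entry only if it contains both a `config.json` with the required keys and an image named exactly `result.png`. A small sketch of adding such an entry, mirroring the full-finetuning example configs committed above (the slot number `5` and the placeholder image are hypothetical):

```python
# Creates one example entry in the layout get_examples() walks:
#   <examples_dir>/<n>/config.json + result.png
import json
import os
from PIL import Image

entry_dir = "apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/5"  # hypothetical slot
os.makedirs(entry_dir, exist_ok=True)

config = {
    "prompt": "a serene landscape in Ghibli style",
    "height": 256,
    "width": 512,
    "num_inference_steps": 50,
    "guidance_scale": 3.5,
    "seed": 42,
    "image": "result.png",  # must be exactly "result.png"
    "use_lora": False,
    "finetune_model_id": "danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning",
}
with open(os.path.join(entry_dir, "config.json"), "w") as f:
    json.dump(config, f, indent=4)

# Placeholder image so the entry validates; in practice this is a generated sample.
Image.new("RGB", (512, 256)).save(os.path.join(entry_dir, "result.png"))
```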
apps/gradio_app/gui_components.py ADDED
@@ -0,0 +1,120 @@
1
+ import gradio as gr
2
+ import torch
3
+ import os
4
+ from .example_handler import get_examples
5
+ from .image_generator import generate_image
6
+ from .project_info import intro_markdown_1, intro_markdown_2, outro_markdown_1
7
+
8
+ def load_example_image_full_finetuning(prompt, height, width, num_inference_steps, guidance_scale, seed, image, finetune_model_id):
9
+ return prompt, height, width, num_inference_steps, guidance_scale, seed, image, finetune_model_id, "Loaded example successfully"
10
+
11
+ def load_example_image_lora(prompt, height, width, num_inference_steps, guidance_scale, seed, image, lora_model_id, base_model_id, lora_scale):
12
+ return prompt, height, width, num_inference_steps, guidance_scale, seed, image, lora_model_id, base_model_id or "stabilityai/stable-diffusion-2-1", lora_scale or 1.2, "Loaded example successfully"
13
+
14
+ def create_gui(model_configs, device):
15
+ finetune_model_id = next((mid for mid, cfg in model_configs.items() if cfg.get('type') == 'full_finetuning'), None)
16
+ lora_model_id = next((mid for mid, cfg in model_configs.items() if cfg.get('type') == 'lora'), None)
17
+
18
+ if not finetune_model_id or not lora_model_id:
19
+ raise ValueError("Missing model IDs in config.")
20
+
21
+ base_model_id = model_configs[lora_model_id].get('base_model_id', 'stabilityai/stable-diffusion-2-1')
22
+ device = torch.device(device)
23
+ dtype = torch.float16 if device.type == "cuda" else torch.float32
24
+ config_path = "configs/model_ckpts.yaml"
25
+
26
+ custom_css = open("apps/gradio_app/static/styles.css", "r").read() if os.path.exists("apps/gradio_app/static/styles.css") else ""
27
+
28
+ examples_full_finetuning = get_examples("apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning", use_lora=False)
29
+ examples_lora = get_examples("apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA", use_lora=True)
30
+
31
+ with gr.Blocks(css=custom_css, theme="ocean") as demo:
32
+ gr.Markdown("# Ghibli Stable Diffusion Synthesis")
33
+ gr.HTML(intro_markdown_1)
34
+ gr.HTML(intro_markdown_2)
35
+ with gr.Tabs():
36
+ with gr.Tab(label="Full Finetuning"):
37
+ with gr.Row():
38
+ with gr.Column(scale=1):
39
+ gr.Markdown("### Image Generation Settings")
40
+ prompt_ft = gr.Textbox(label="Prompt", placeholder="e.g., 'a serene landscape in Ghibli style'", lines=2)
41
+ with gr.Group():
42
+ gr.Markdown("#### Image Dimensions")
43
+ with gr.Row():
44
+ height_ft = gr.Slider(32, 4096, 512, step=8, label="Height")
45
+ width_ft = gr.Slider(32, 4096, 512, step=8, label="Width")
46
+ with gr.Accordion("Advanced Settings", open=False):
47
+ num_inference_steps_ft = gr.Slider(1, 100, 50, step=1, label="Inference Steps")
48
+ guidance_scale_ft = gr.Slider(1.0, 20.0, 3.5, step=0.5, label="Guidance Scale")
49
+ random_seed_ft = gr.Checkbox(label="Use Random Seed")
50
+ seed_ft = gr.Slider(0, 4294967295, 42, step=1, label="Seed")
51
+ gr.Markdown("#### Model Configuration")
52
+ finetune_model_path_ft = gr.Dropdown(label="Fine-tuned Model", choices=[mid for mid, cfg in model_configs.items() if cfg.get('type') == 'full_finetuning'], value=finetune_model_id)
53
+ with gr.Column(scale=1):
54
+ gr.Markdown("### Generated Result")
55
+ output_image_ft = gr.Image(label="Generated Image", interactive=False, height=512)
56
+ output_text_ft = gr.Textbox(label="Status", interactive=False, lines=3)
57
+ generate_btn_ft = gr.Button("Generate Image", variant="primary")
58
+ stop_btn_ft = gr.Button("Stop Generation")
59
+ gr.Markdown("### Examples for Full Finetuning")
60
+ gr.Examples(examples=examples_full_finetuning, inputs=[prompt_ft, height_ft, width_ft, num_inference_steps_ft, guidance_scale_ft, seed_ft, output_image_ft, finetune_model_path_ft],
61
+ outputs=[prompt_ft, height_ft, width_ft, num_inference_steps_ft, guidance_scale_ft, seed_ft, output_image_ft, finetune_model_path_ft, output_text_ft],
62
+ fn=load_example_image_full_finetuning, cache_examples=False, examples_per_page=4)
63
+
64
+ with gr.Tab(label="LoRA"):
65
+ with gr.Row():
66
+ with gr.Column(scale=1):
67
+ gr.Markdown("### Image Generation Settings")
68
+ prompt_lora = gr.Textbox(label="Prompt", placeholder="e.g., 'a serene landscape in Ghibli style'", lines=2)
69
+ with gr.Group():
70
+ gr.Markdown("#### Image Dimensions")
71
+ with gr.Row():
72
+ height_lora = gr.Slider(32, 4096, 512, step=8, label="Height")
73
+ width_lora = gr.Slider(32, 4096, 512, step=8, label="Width")
74
+ with gr.Accordion("Advanced Settings", open=False):
75
+ num_inference_steps_lora = gr.Slider(1, 100, 50, step=1, label="Inference Steps")
76
+ guidance_scale_lora = gr.Slider(1.0, 20.0, 3.5, step=0.5, label="Guidance Scale")
77
+ lora_scale_lora = gr.Slider(0.0, 2.0, 1.2, step=0.1, label="LoRA Scale")
78
+ random_seed_lora = gr.Checkbox(label="Use Random Seed")
79
+ seed_lora = gr.Slider(0, 4294967295, 42, step=1, label="Seed")
80
+ gr.Markdown("#### Model Configuration")
81
+ lora_model_path_lora = gr.Dropdown(label="LoRA Model", choices=[mid for mid, cfg in model_configs.items() if cfg.get('type') == 'lora'], value=lora_model_id)
82
+ base_model_path_lora = gr.Dropdown(label="Base Model", choices=[model_configs[mid].get('base_model_id') for mid in model_configs if model_configs[mid].get('base_model_id')], value=base_model_id)
83
+ with gr.Column(scale=1):
84
+ gr.Markdown("### Generated Result")
85
+ output_image_lora = gr.Image(label="Generated Image", interactive=False, height=512)
86
+ output_text_lora = gr.Textbox(label="Status", interactive=False, lines=3)
87
+ generate_btn_lora = gr.Button("Generate Image", variant="primary")
88
+ stop_btn_lora = gr.Button("Stop Generation")
89
+ gr.Markdown("### Examples for LoRA")
90
+ gr.Examples(examples=examples_lora, inputs=[prompt_lora, height_lora, width_lora, num_inference_steps_lora, guidance_scale_lora, seed_lora, output_image_lora, lora_model_path_lora, base_model_path_lora, lora_scale_lora],
91
+ outputs=[prompt_lora, height_lora, width_lora, num_inference_steps_lora, guidance_scale_lora, seed_lora, output_image_lora, lora_model_path_lora, base_model_path_lora, lora_scale_lora, output_text_lora],
92
+ fn=load_example_image_lora, cache_examples=False, examples_per_page=4)
93
+
94
+ gr.HTML(outro_markdown_1)
95
+
96
+ generate_event_ft = generate_btn_ft.click(
97
+ fn=generate_image,
98
+ inputs=[prompt_ft, height_ft, width_ft,
99
+ num_inference_steps_ft, guidance_scale_ft, seed_ft,
100
+ random_seed_ft, gr.State(False), finetune_model_path_ft,
101
+ gr.State(None), gr.State(None), gr.State(None),
102
+ gr.State(config_path), gr.State(device), gr.State(dtype)],
103
+ outputs=[output_image_ft, output_text_ft]
104
+ )
105
+ generate_event_lora = generate_btn_lora.click(
106
+ fn=generate_image,
107
+ inputs=[prompt_lora, height_lora, width_lora,
108
+ num_inference_steps_lora, guidance_scale_lora, seed_lora,
109
+ random_seed_lora, gr.State(True), gr.State(None),
110
+ lora_model_path_lora, base_model_path_lora, lora_scale_lora,
111
+ gr.State(config_path), gr.State(device), gr.State(dtype)],
112
+ outputs=[output_image_lora, output_text_lora]
113
+ )
114
+
115
+ stop_btn_ft.click(fn=None, inputs=None, outputs=None, cancels=[generate_event_ft])
116
+ stop_btn_lora.click(fn=None, inputs=None, outputs=None, cancels=[generate_event_lora])
117
+
118
+ demo.unload(lambda: torch.cuda.empty_cache())
119
+
120
+ return demo
apps/gradio_app/image_generator.py ADDED
@@ -0,0 +1,54 @@
1
+ import os
2
+ import sys
3
+ import torch
4
+
5
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..',
6
+ 'src', 'ghibli_stable_diffusion_synthesis',
7
+ 'inference')))
8
+
9
+ from full_finetuning import inference_process as full_finetuning_inference
10
+ from lora import inference_process as lora_inference
11
+
12
+ def generate_image(prompt, height, width, num_inference_steps, guidance_scale, seed,
13
+ random_seed, use_lora, finetune_model_id, lora_model_id, base_model_id,
14
+ lora_scale, config_path, device, dtype):
15
+ batch_size = 1
16
+ if random_seed:
17
+ seed = torch.randint(0, 4294967295, (1,)).item()
18
+ try:
19
+ model_id = finetune_model_id
20
+ if not use_lora:
21
+ pil_image = full_finetuning_inference(
22
+ prompt=prompt,
23
+ height=height,
24
+ width=width,
25
+ num_inference_steps=num_inference_steps,
26
+ guidance_scale=guidance_scale,
27
+ batch_size=batch_size,
28
+ seed=seed,
29
+ config_path=config_path,
30
+ model_id=model_id,
31
+ device=device,
32
+ dtype=dtype
33
+ )
34
+ else:
35
+ model_id = lora_model_id
36
+ pil_image = lora_inference(
37
+ prompt=prompt,
38
+ height=height,
39
+ width=width,
40
+ num_inference_steps=num_inference_steps,
41
+ guidance_scale=guidance_scale,
42
+ batch_size=batch_size,
43
+ seed=seed,
44
+ lora_scale=lora_scale,
45
+ config_path=config_path,
46
+ model_id=model_id,
47
+ # base_model_id=base_model_id,
48
+ device=device,
49
+ dtype=dtype
50
+ )
51
+ return pil_image, f"Generated image successfully! Seed used: {seed}"
52
+ except Exception as e:
53
+ return None, f"Failed to generate image: {e}"
54
+
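
For orientation, a rough sketch of how the GUI drives this function on the full fine-tuning path (assuming the repository root is on `sys.path` so the `apps` package and the inference modules under `src/` resolve, and that the checkpoints named in `configs/model_ckpts.yaml` are reachable):

```python
# Hypothetical direct call; in the app these arguments come from the Gradio
# widgets and gr.State values wired up in gui_components.py.
import torch
from apps.gradio_app.image_generator import generate_image  # assumes apps/ is importable as a package

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if device.type == "cuda" else torch.float32

image, status = generate_image(
    prompt="a serene landscape in Ghibli style",
    height=512, width=512,
    num_inference_steps=50, guidance_scale=3.5,
    seed=42, random_seed=False,
    use_lora=False,
    finetune_model_id="danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning",
    lora_model_id=None, base_model_id=None, lora_scale=None,
    config_path="configs/model_ckpts.yaml",
    device=device, dtype=dtype,
)
print(status)  # "Generated image successfully! Seed used: 42" or an error message
if image is not None:
    image.save("ghibli_sample.png")
```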
apps/gradio_app/old-image_generator.py ADDED
@@ -0,0 +1,77 @@
+ import torch
+ from PIL import Image
+ import numpy as np
+ from transformers import CLIPTextModel, CLIPTokenizer
+ from diffusers import (
+     AutoencoderKL, UNet2DConditionModel,
+     PNDMScheduler, StableDiffusionPipeline
+ )
+
+ from tqdm import tqdm
+ from .config_loader import load_model_configs
+
+ def generate_image(prompt, height, width, num_inference_steps, guidance_scale, seed,
+                    random_seed, use_lora, finetune_model_id, lora_model_id, base_model_id,
+                    lora_scale, config_path, device, dtype):
+     if not prompt or height % 8 != 0 or width % 8 != 0 or num_inference_steps not in range(1, 101) or \
+        guidance_scale < 1.0 or guidance_scale > 20.0 or seed < 0 or seed > 4294967295 or \
+        (use_lora and (lora_scale < 0.0 or lora_scale > 2.0)):
+         return None, "Invalid input parameters."
+
+     model_configs = load_model_configs(config_path)
+     finetune_model_path = model_configs.get(finetune_model_id, {}).get('local_dir', finetune_model_id)
+     lora_model_path = model_configs.get(lora_model_id, {}).get('local_dir', lora_model_id)
+     base_model_path = model_configs.get(base_model_id, {}).get('local_dir', base_model_id)
+
+     generator = torch.Generator(device=device).manual_seed(torch.randint(0, 4294967295, (1,)).item() if random_seed else int(seed))
+
+     try:
+         if use_lora:
+             # Load base pipeline
+             pipe = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=dtype, use_safetensors=True)
+
+             # Add LoRA weights with specified rank and scale
+             pipe.load_lora_weights(lora_model_path, adapter_name="ghibli-lora",
+                                    lora_scale=lora_scale)
+
+             pipe = pipe.to(device)
+             vae, tokenizer, text_encoder, unet, scheduler = pipe.vae, pipe.tokenizer, pipe.text_encoder, pipe.unet, PNDMScheduler.from_config(pipe.scheduler.config)
+         else:
+             vae = AutoencoderKL.from_pretrained(finetune_model_path, subfolder="vae", torch_dtype=dtype).to(device)
+             tokenizer = CLIPTokenizer.from_pretrained(finetune_model_path, subfolder="tokenizer")
+             text_encoder = CLIPTextModel.from_pretrained(finetune_model_path, subfolder="text_encoder", torch_dtype=dtype).to(device)
+             unet = UNet2DConditionModel.from_pretrained(finetune_model_path, subfolder="unet", torch_dtype=dtype).to(device)
+             scheduler = PNDMScheduler.from_pretrained(finetune_model_path, subfolder="scheduler")
+
+         text_input = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+         text_embeddings = text_encoder(text_input.input_ids.to(device))[0].to(dtype=dtype)
+
+         uncond_input = tokenizer([""] * 1, padding="max_length", max_length=text_input.input_ids.shape[-1], return_tensors="pt")
+         uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0].to(dtype=dtype)
+         text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+         latents = torch.randn((1, unet.config.in_channels, height // 8, width // 8), generator=generator, dtype=dtype, device=device)
+         scheduler.set_timesteps(num_inference_steps)
+         latents = latents * scheduler.init_noise_sigma
+
+         for t in tqdm(scheduler.timesteps, desc="Generating image"):
+             latent_model_input = torch.cat([latents] * 2)
+             latent_model_input = scheduler.scale_model_input(latent_model_input, t)
+             noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+             noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+             noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+             latents = scheduler.step(noise_pred, t, latents).prev_sample
+
+         image = vae.decode(latents / vae.config.scaling_factor).sample
+         image = (image / 2 + 0.5).clamp(0, 1).detach().cpu().permute(0, 2, 3, 1).numpy()
+         pil_image = Image.fromarray((image[0] * 255).round().astype("uint8"))
+
+         if use_lora:
+             del pipe
+         else:
+             del vae, tokenizer, text_encoder, unet, scheduler
+         torch.cuda.empty_cache()
+
+         return pil_image, f"Generated image successfully! Seed used: {seed}"
+     except Exception as e:
+         return None, f"Failed to generate image: {e}"
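The superseded module above unrolls the classifier-free-guidance denoising loop by hand; for comparison only, a minimal sketch of the roughly equivalent high-level diffusers call (not taken from this commit, and it omits the LoRA branch and scheduler swap):

# Rough equivalent of the manual loop using the standard pipeline API.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
).to("cuda")
image = pipe(
    "a serene landscape in Ghibli style",
    height=512, width=512, num_inference_steps=50, guidance_scale=3.5,
    generator=torch.Generator("cuda").manual_seed(42),
).images[0]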
apps/gradio_app/project_info.py ADDED
@@ -0,0 +1,36 @@
+ intro_markdown_1 = """
+ <h3>Create Studio Ghibli-style art with Stable Diffusion AI.</h3>
+ """.strip()
+
+ intro_markdown_2 = """
+ <div style="text-align: left; font-size: 14px; display: flex; flex-direction: column; gap: 10px;">
+ <div style="display: flex; align-items: center; justify-content: left; gap: 8px;">
+ You can explore this GitHub Source code: <a href="https://github.com/danhtran2mind/Ghibli-Stable-Diffusion-Synthesis">
+ <img src="https://img.shields.io/badge/GitHub-danhtran2mind%2FGhibli--Stable--Diffusion--Synthesis-blue?style=flat&logo=github" alt="GitHub Repo">
+ </a>
+ </div>
+ <div style="display: flex; align-items: center; justify-content: left; gap: 8px;">
+ And HuggingFace Model Hubs:
+ <a href="https://huggingface.co/danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning">
+ <img src="https://img.shields.io/badge/HuggingFace-danhtran2mind%2FGhibli--Stable--Diffusion--2.1--Base--finetuning-yellow?style=flat&logo=huggingface" alt="HuggingFace Model Hub">
+ </a>, and
+ <a href="https://huggingface.co/danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA">
+ <img src="https://img.shields.io/badge/HuggingFace-danhtran2mind%2FGhibli--Stable--Diffusion--2.1--LoRA-yellow?style=flat&logo=huggingface" alt="HuggingFace Model Hub">
+ </a>
+ </div>
+ </div>
+ """.strip()
+
+ outro_markdown_1 = """
+ <div style="text-align: left; font-size: 14px; display: flex; flex-direction: column; gap: 10px;">
+ <div style="display: flex; align-items: center; justify-content: left; gap: 8px;">
+ This is the pre-trained models on our Hugging Face Model Hubs:
+ <a href="https://huggingface.co/stabilityai/stable-diffusion-2-1">
+ <img src="https://img.shields.io/badge/HuggingFace-stabilityai%2Fstable--diffusion--2--1-yellow?style=flat&logo=huggingface" alt="HuggingFace Model Hub">
+ </a>, and
+ <a href="https://huggingface.co/stabilityai/stable-diffusion-2-1-base">
+ <img src="https://img.shields.io/badge/HuggingFace-stabilityai%2Fstable--diffusion--2--1--base-yellow?style=flat&logo=huggingface" alt="HuggingFace Model Hub">
+ </a>
+ </div>
+ </div>
+ """.strip()
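These strings are plain HTML snippets intended for gr.Markdown; a small sketch of the presumed wiring follows (the actual layout lives in gui_components.py, which is not shown in this hunk, so the import path and placement are assumptions):

# Assumed usage of the project_info strings inside the Gradio layout.
import gradio as gr
from apps.gradio_app.project_info import intro_markdown_1, intro_markdown_2, outro_markdown_1

with gr.Blocks() as demo:
    gr.Markdown("# Ghibli-Style Image Generator")
    gr.Markdown(intro_markdown_1)
    gr.Markdown(intro_markdown_2)
    gr.Markdown(outro_markdown_1)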
apps/gradio_app/setup_scripts.py ADDED
@@ -0,0 +1,64 @@
+ import subprocess
+ import sys
+ import os
+
+ def run_script(script_path, args=None):
+     """
+     Run a Python script using subprocess with optional arguments and handle errors.
+     Returns True if successful, False otherwise.
+     """
+     if not os.path.isfile(script_path):
+         print(f"Script not found: {script_path}")
+         return False
+
+     try:
+         command = [sys.executable, script_path]
+         if args:
+             command.extend(args)
+         result = subprocess.run(
+             command,
+             check=True,
+             text=True,
+             capture_output=True
+         )
+         print(f"Successfully executed {script_path}")
+         print(result.stdout)
+         return True
+     except subprocess.CalledProcessError as e:
+         print(f"Error executing {script_path}:")
+         print(e.stderr)
+         return False
+     except Exception as e:
+         print(f"Unexpected error executing {script_path}: {str(e)}")
+         return False
+
+ def main():
+     """
+     Main function to execute download_ckpts.py with proper error handling.
+     """
+     scripts_dir = "scripts"
+     scripts = [
+         {
+             "path": os.path.join(scripts_dir, "download_ckpts.py"),
+             "args": []  # Empty list for args to avoid NoneType issues
+         },
+         # Uncomment and add arguments if needed for setup_third_party.py
+         # {
+         #     "path": os.path.join(scripts_dir, "setup_third_party.py"),
+         #     "args": []
+         # }
+     ]
+
+     for script in scripts:
+         script_path = script["path"]
+         args = script.get("args", [])  # Safely get args with default empty list
+         print(f"Starting execution of {script_path}{' with args: ' + ' '.join(args) if args else ''}\n")
+
+         if not run_script(script_path, args):
+             print(f"Stopping execution due to error in {script_path}")
+             sys.exit(1)
+
+         print(f"Completed execution of {script_path}\n")
+
+ if __name__ == "__main__":
+     main()
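A short usage sketch: the same helpers can be driven directly, for example to pre-download checkpoints before launching the UI. The package-style import path is an assumption, not part of this commit.

# Hypothetical direct use of the setup helpers (import path assumed).
from apps.gradio_app.setup_scripts import main, run_script

main()                                       # runs scripts/download_ckpts.py; exits non-zero on failure
run_script("scripts/download_ckpts.py", [])  # or invoke a single script explicitly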
apps/gradio_app/static/styles.css ADDED
@@ -0,0 +1,213 @@
+ :root {
+     --primary-color: #10b981; /* Updated to success color */
+     --primary-hover: #0a8f66; /* Darkened shade of #10b981 for hover */
+     --accent-color: #8a1bf2; /* New variable for the second gradient color */
+     --accent-hover: #6b21a8; /* Darkened shade of #8a1bf2 for hover */
+     --secondary-color: #64748b;
+     --success-color: #10b981;
+     --warning-color: #f59e0b;
+     --danger-color: #ef4444;
+     --border-radius: 0.5rem; /* Relative unit */
+     --shadow-sm: 0 0.0625rem 0.125rem 0 rgba(0, 0, 0, 0.05);
+     --shadow-md: 0 0.25rem 0.375rem -0.0625rem rgba(0, 0, 0, 0.1);
+     --shadow-lg: 0 0.625rem 0.9375rem -0.1875rem rgba(0, 0, 0, 0.1);
+ }
+
+ /* Container Styles */
+ .gradio-container {
+     max-width: 75rem !important; /* Relative to viewport */
+     margin: 0 auto !important;
+     padding: 1.25rem !important; /* Relative padding */
+     font-family: 'Segoe UI', system-ui, -apple-system, sans-serif !important;
+ }
+
+ /* Card/Panel Styles */
+ .svelte-15lo0d9, .panel {
+     background: var(--block-background-fill) !important;
+     border-radius: var(--border-radius) !important;
+     box-shadow: var(--shadow-md) !important;
+     border: 0.0625rem solid var(--border-color-primary) !important;
+     backdrop-filter: blur(0.625rem) !important;
+ }
+
+ /* Button Styles */
+ button.primary {
+     background: linear-gradient(135deg, var(--primary-color), var(--accent-color)) !important;
+     border: none !important;
+     border-radius: var(--border-radius) !important;
+     padding: 0.625rem 1.25rem !important; /* Relative padding */
+     font-weight: 600 !important;
+     font-size: 1rem !important; /* Relative font size */
+     transition: all 0.3s ease !important;
+     box-shadow: var(--shadow-sm) !important;
+ }
+
+ button.primary:hover {
+     background: linear-gradient(135deg, var(--primary-hover), var(--accent-hover)) !important;
+     transform: translateY(-0.0625rem) !important;
+     box-shadow: var(--shadow-md) !important;
+ }
+
+ button.secondary {
+     background: transparent !important;
+     border: 0.0625rem solid var(--border-color-primary) !important;
+     border-radius: var(--border-radius) !important;
+     color: var(--body-text-color) !important;
+     font-weight: 500 !important;
+     font-size: 0.875rem !important; /* Relative font size */
+ }
+
+ /* Slider Styles */
+ .slider_input_container input[type="range"][name="cowbell"] {
+     -webkit-appearance: none !important;
+     width: 100% !important;
+     height: 0.5rem !important; /* Relative height */
+     border-radius: var(--border-radius) !important;
+     background: linear-gradient(90deg, var(--primary-color), var(--accent-color)) !important;
+     outline: none !important;
+ }
+
+ .slider_input_container input[type="range"][name="cowbell"]::-webkit-slider-thumb {
+     -webkit-appearance: none !important;
+     width: 1rem !important; /* Relative size */
+     height: 1rem !important;
+     border-radius: 50% !important;
+     background: var(--accent-color) !important;
+     cursor: pointer !important;
+     box-shadow: var(--shadow-sm) !important;
+     border: 0.0625rem solid var(--border-color-primary) !important;
+ }
+
+ .slider_input_container input[type="range"][name="cowbell"]::-webkit-slider-thumb:hover {
+     background: var(--accent-color) !important;
+     box-shadow: var(--shadow-md) !important;
+ }
+
+ .slider_input_container input[type="range"][name="cowbell"]::-moz-range-track {
+     height: 0.5rem !important; /* Relative height */
+     border-radius: var(--border-radius) !important;
+     background: linear-gradient(90deg, var(--primary-color), var(--accent-color)) !important;
+ }
+
+ .slider_input_container input[type="range"][name="cowbell"]::-moz-range-thumb {
+     width: 1rem !important; /* Relative size */
+     height: 1rem !important;
+     border-radius: 50% !important;
+     background: var(--accent-color) !important;
+     cursor: pointer !important;
+     box-shadow: var(--shadow-sm) !important;
+     border: 0.0625rem solid var(--border-color-primary) !important;
+ }
+
+ .slider_input_container input[type="range"][name="cowbell"]::-moz-range-thumb:hover {
+     background: var(--accent-color) !important;
+     box-shadow: var(--shadow-md) !important;
+ }
+
+ /* Header Styles */
+ h1, h2, h3, h4, h5, h6 {
+     font-weight: 700 !important;
+     color: var(--body-text-color) !important;
+     letter-spacing: -0.02em !important;
+ }
+
+ h1 {
+     font-size: 2.5rem !important; /* Kept as is, suitable for zooming */
+     margin-bottom: 1rem !important;
+ }
+
+ h2 {
+     font-size: 1.75rem !important; /* Kept as is, suitable for zooming */
+     margin: 1.5rem 0 1rem 0 !important;
+ }
+
+ /* Text Styles */
+ p, .prose {
+     color: var(--body-text-color-subdued) !important;
+     line-height: 1.6 !important;
+     font-size: 1rem !important; /* Relative font size */
+ }
+
+ /* Alert/Notification Styles */
+ .alert-info {
+     background: linear-gradient(135deg, #dbeafe, #bfdbfe) !important;
+     border: 0.0625rem solid #93c5fd !important;
+     border-radius: var(--border-radius) !important;
+     color: #1e40af !important;
+     font-size: 0.875rem !important; /* Relative font size */
+ }
+
+ .alert-warning {
+     background: linear-gradient(135deg, #fef3c7, #fde68a) !important;
+     border: 0.0625rem solid #fcd34d !important;
+     border-radius: var(--border-radius) !important;
+     color: #92400e !important;
+     font-size: 0.875rem !important; /* Relative font size */
+ }
+
+ .alert-error {
+     background: linear-gradient(135deg, #fecaca, #fca5a5) !important;
+     border: 0.0625rem solid #f87171 !important;
+     border-radius: var(--border-radius) !important;
+     color: #991b1b !important;
+     font-size: 0.875rem !important; /* Relative font size */
+ }
+
+ /* Scrollbar (Webkit browsers) */
+ ::-webkit-scrollbar {
+     width: 0.5rem !important; /* Relative size */
+     height: 0.5rem !important;
+ }
+
+ ::-webkit-scrollbar-track {
+     background: var(--background-fill-secondary) !important;
+     border-radius: 0.25rem !important;
+ }
+
+ ::-webkit-scrollbar-thumb {
+     background: var(--secondary-color) !important;
+     border-radius: 0.25rem !important;
+ }
+
+ ::-webkit-scrollbar-thumb:hover {
+     background: var(--primary-color) !important;
+ }
+
+ /* Tab Styles */
+ .gradio-container .tabs button {
+     font-size: 1.0625rem !important; /* Relative font size (17px equivalent) */
+     font-weight: bold !important;
+ }
+
+ /* Dark Theme Specific Overrides */
+ @media (prefers-color-scheme: dark) {
+     :root {
+         --shadow-sm: 0 0.0625rem 0.125rem 0 rgba(0, 0, 0, 0.2);
+         --shadow-md: 0 0.25rem 0.375rem -0.0625rem rgba(0, 0, 0, 0.3);
+         --shadow-lg: 0 0.625rem 0.9375rem -0.1875rem rgba(0, 0, 0, 0.3);
+     }
+ }
+
+ /* Light Theme Specific Overrides */
+ @media (prefers-color-scheme: light) {
+     .gradio-container {
+         background: linear-gradient(135deg, #f8fafc, #f1f5f9) !important;
+     }
+ }
+
+ /* Responsive adjustments for zoom */
+ @media screen and (max-width: 48rem) { /* 768px equivalent */
+     .gradio-container {
+         padding: 0.625rem !important;
+     }
+     h1 {
+         font-size: 2rem !important;
+     }
+     h2 {
+         font-size: 1.5rem !important;
+     }
+     button.primary {
+         padding: 0.5rem 1rem !important;
+         font-size: 0.875rem !important;
+     }
+ }
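This stylesheet is presumably injected into the Blocks app through Gradio's css argument; a minimal sketch of that wiring, assuming the hook-up happens in the GUI module rather than in this file:

# Assumed way of attaching styles.css to the Gradio app.
import gradio as gr

with open("apps/gradio_app/static/styles.css", "r", encoding="utf-8") as f:
    custom_css = f.read()

with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# Ghibli-Style Image Generator")

demo.launch()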
apps/old-gradio_app.py ADDED
@@ -0,0 +1,261 @@
1
+ import argparse
2
+ import json
3
+ from pathlib import Path
4
+ import os
5
+ import gradio as gr
6
+ import torch
7
+ from PIL import Image
8
+ import numpy as np
9
+ from transformers import CLIPTextModel, CLIPTokenizer
10
+ from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
11
+ from tqdm import tqdm
12
+ from transformers import HfArgumentParser
13
+
14
+ def get_examples(examples_dir: str = "apps/gradip_app/assets/examples/ghibli-fine-tuned-sd-2.1") -> list:
15
+ """
16
+ Load example data from the assets/examples directory.
17
+ Each example is a subdirectory containing a config.json and an image file.
18
+ Returns a list of [prompt, height, width, num_inference_steps, guidance_scale, seed, image_path].
19
+ """
20
+ # Check if the directory exists
21
+ if not os.path.exists(examples_dir) or not os.path.isdir(examples_dir):
22
+ raise ValueError(f"Directory {examples_dir} does not exist or is not a directory")
23
+
24
+ # Get list of subfolder paths (e.g., 1, 2, etc.)
25
+ all_examples_dir = [os.path.join(examples_dir, d) for d in os.listdir(examples_dir)
26
+ if os.path.isdir(os.path.join(examples_dir, d))]
27
+
28
+ ans = []
29
+ for example_dir in all_examples_dir:
30
+ config_path = os.path.join(example_dir, "config.json")
31
+ image_path = os.path.join(example_dir, "result.png")
32
+
33
+ # Check if config.json and result.png exist
34
+ if not os.path.isfile(config_path):
35
+ print(f"Warning: config.json not found in {example_dir}")
36
+ continue
37
+ if not os.path.isfile(image_path):
38
+ print(f"Warning: result.png not found in {example_dir}")
39
+ continue
40
+
41
+ try:
42
+ with open(config_path, 'r') as f:
43
+ example_dict = json.load(f)
44
+ except (json.JSONDecodeError, IOError) as e:
45
+ print(f"Error reading or parsing {config_path}: {e}")
46
+ continue
47
+
48
+ # Required keys for the config
49
+ required_keys = ["prompt", "height", "width", "num_inference_steps", "guidance_scale", "seed", "image"]
50
+ if not all(key in example_dict for key in required_keys):
51
+ print(f"Warning: Missing required keys in {config_path}")
52
+ continue
53
+
54
+ # Verify that the image key in config.json matches 'result.png'
55
+ if example_dict["image"] != "result.png":
56
+ print(f"Warning: Image key in {config_path} does not match 'result.png'")
57
+ continue
58
+
59
+ try:
60
+ example_list = [
61
+ example_dict["prompt"],
62
+ example_dict["height"],
63
+ example_dict["width"],
64
+ example_dict["num_inference_steps"],
65
+ example_dict["guidance_scale"],
66
+ example_dict["seed"],
67
+ image_path # Use verified image path
68
+ ]
69
+ ans.append(example_list)
70
+ except KeyError as e:
71
+ print(f"Error processing {config_path}: Missing key {e}")
72
+ continue
73
+
74
+ if not ans:
75
+ ans = [
76
+ ["a serene landscape in Ghibli style", 64, 64, 50, 3.5, 42, None]
77
+ ]
78
+ return ans
79
+
80
+ def create_demo(
81
+ model_name: str = "danhtran2mind/ghibli-fine-tuned-sd-2.1",
82
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
83
+ ):
84
+ # Convert device string to torch.device
85
+ device = torch.device(device)
86
+ dtype = torch.float16 if device.type == "cuda" else torch.float32
87
+
88
+ # Load models with consistent dtype
89
+ vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae", torch_dtype=dtype).to(device)
90
+ tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
91
+ text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder", torch_dtype=dtype).to(device)
92
+ unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet", torch_dtype=dtype).to(device)
93
+ scheduler = PNDMScheduler.from_pretrained(model_name, subfolder="scheduler")
94
+
95
+ def generate_image(prompt, height, width, num_inference_steps, guidance_scale, seed, random_seed):
96
+ if not prompt:
97
+ return None, "Prompt cannot be empty."
98
+ if height % 8 != 0 or width % 8 != 0:
99
+ return None, "Height and width must be divisible by 8 (e.g., 256, 512, 1024)."
100
+ if num_inference_steps < 1 or num_inference_steps > 100:
101
+ return None, "Number of inference steps must be between 1 and 100."
102
+ if guidance_scale < 1.0 or guidance_scale > 20.0:
103
+ return None, "Guidance scale must be between 1.0 and 20.0."
104
+ if seed < 0 or seed > 4294967295:
105
+ return None, "Seed must be between 0 and 4294967295."
106
+
107
+ batch_size = 1
108
+ if random_seed:
109
+ seed = torch.randint(0, 4294967295, (1,)).item()
110
+ generator = torch.Generator(device=device).manual_seed(int(seed))
111
+
112
+ text_input = tokenizer(
113
+ [prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
114
+ )
115
+ with torch.no_grad():
116
+ text_embeddings = text_encoder(text_input.input_ids.to(device))[0].to(dtype=dtype)
117
+
118
+ max_length = text_input.input_ids.shape[-1]
119
+ uncond_input = tokenizer(
120
+ [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
121
+ )
122
+ with torch.no_grad():
123
+ uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0].to(dtype=dtype)
124
+
125
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
126
+
127
+ latents = torch.randn(
128
+ (batch_size, unet.config.in_channels, height // 8, width // 8),
129
+ generator=generator,
130
+ dtype=dtype,
131
+ device=device
132
+ )
133
+
134
+ scheduler.set_timesteps(num_inference_steps)
135
+ latents = latents * scheduler.init_noise_sigma
136
+
137
+ for t in tqdm(scheduler.timesteps, desc="Generating image"):
138
+ latent_model_input = torch.cat([latents] * 2)
139
+ latent_model_input = scheduler.scale_model_input(latent_model_input, t)
140
+
141
+ with torch.no_grad():
142
+ if device.type == "cuda":
143
+ with torch.autocast(device_type="cuda", dtype=torch.float16):
144
+ noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
145
+ else:
146
+ noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
147
+
148
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
149
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
150
+ latents = scheduler.step(noise_pred, t, latents).prev_sample
151
+
152
+ with torch.no_grad():
153
+ latents = latents / vae.config.scaling_factor
154
+ image = vae.decode(latents).sample
155
+
156
+ image = (image / 2 + 0.5).clamp(0, 1)
157
+ image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
158
+ image = (image * 255).round().astype("uint8")
159
+ pil_image = Image.fromarray(image[0])
160
+
161
+ return pil_image, f"Image generated successfully! Seed used: {seed}"
162
+
163
+ def load_example_image(prompt, height, width, num_inference_steps, guidance_scale, seed, image_path):
164
+ """
165
+ Load the image for the selected example and update input fields.
166
+ """
167
+ if image_path and Path(image_path).exists():
168
+ try:
169
+ image = Image.open(image_path)
170
+ return prompt, height, width, num_inference_steps, guidance_scale, seed, image, f"Loaded image: {image_path}"
171
+ except Exception as e:
172
+ return prompt, height, width, num_inference_steps, guidance_scale, seed, None, f"Error loading image: {e}"
173
+ return prompt, height, width, num_inference_steps, guidance_scale, seed, None, "No image available"
174
+
175
+ badges_text = r"""
176
+ <div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
177
+ <a href="https://huggingface.co/spaces/danhtran2mind/ghibli-fine-tuned-sd-2.1"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Hugging%20Face&message=Space&color=orange"></a>
178
+ </div>
179
+ """.strip()
180
+
181
+ with gr.Blocks() as demo:
182
+ gr.Markdown("# Ghibli-Style Image Generator")
183
+ gr.Markdown(badges_text)
184
+ gr.Markdown("Generate images in Ghibli style using a fine-tuned Stable Diffusion model. Select an example below to load a pre-generated image or enter a prompt to generate a new one.")
185
+ gr.Markdown("""**Note:** For CPU inference, execution time is long (e.g., for resolution 512 × 512) with 50 inference steps, time is approximately 1700 seconds).""")
186
+
187
+ with gr.Row():
188
+ with gr.Column():
189
+ prompt = gr.Textbox(label="Prompt", placeholder="e.g., 'a serene landscape in Ghibli style'")
190
+ with gr.Row():
191
+ width = gr.Slider(32, 4096, 512, step=8, label="Generation Width")
192
+ height = gr.Slider(32, 4096, 512, step=8, label="Generation Height")
193
+ with gr.Accordion("Advanced Options", open=False):
194
+ num_inference_steps = gr.Slider(1, 100, 50, step=1, label="Number of Inference Steps")
195
+ guidance_scale = gr.Slider(1.0, 20.0, 3.5, step=0.5, label="Guidance Scale")
196
+ seed = gr.Number(42, label="Seed (0 to 4294967295)")
197
+ random_seed = gr.Checkbox(label="Use Random Seed", value=False)
198
+ generate_btn = gr.Button("Generate Image")
199
+
200
+ with gr.Column():
201
+ output_image = gr.Image(label="Generated Image")
202
+ output_text = gr.Textbox(label="Status")
203
+
204
+ examples = get_examples("assets/examples/ghibli-fine-tuned-sd-2.1")
205
+ gr.Examples(
206
+ examples=examples,
207
+ inputs=[prompt, height, width, num_inference_steps, guidance_scale, seed, output_image],
208
+ outputs=[prompt, height, width, num_inference_steps, guidance_scale, seed, output_image, output_text],
209
+ fn=load_example_image,
210
+ cache_examples=False
211
+ )
212
+
213
+ generate_btn.click(
214
+ fn=generate_image,
215
+ inputs=[prompt, height, width, num_inference_steps, guidance_scale, seed, random_seed],
216
+ outputs=[output_image, output_text]
217
+ )
218
+
219
+ return demo
220
+
221
+ if __name__ == "__main__":
222
+ parser = argparse.ArgumentParser(description="Ghibli-Style Image Generator using a fine-tuned Stable Diffusion model.")
223
+ parser.add_argument(
224
+ "--local_model",
225
+ action="store_true",
226
+ default=True,
227
+ help="Use local model path instead of Hugging Face model."
228
+ )
229
+ parser.add_argument(
230
+ "--model_name",
231
+ type=str,
232
+ default="danhtran2mind/ghibli-fine-tuned-sd-2.1",
233
+ help="Model name or path for the fine-tuned Stable Diffusion model."
234
+ )
235
+ parser.add_argument(
236
+ "--device",
237
+ type=str,
238
+ default="cuda" if torch.cuda.is_available() else "cpu",
239
+ help="Device to run the model on (e.g., 'cuda', 'cpu')."
240
+ )
241
+ parser.add_argument(
242
+ "--port",
243
+ type=int,
244
+ default=7860,
245
+ help="Port to run the Gradio app on."
246
+ )
247
+ parser.add_argument(
248
+ "--share",
249
+ action="store_true",
250
+ default=False,
251
+ help="Set to True for public sharing (Hugging Face Spaces)."
252
+ )
253
+
254
+ args = parser.parse_args()
255
+
256
+ # Set model_name based on local_model flag
257
+ if args.local_model:
258
+ args.model_name = "./checkpoints/ghibli-fine-tuned-sd-2.1"
259
+
260
+ demo = create_demo(args.model_name, args.device)
261
+ demo.launch(server_port=args.port, share=args.share)
apps/old2-gradio_app.py ADDED
@@ -0,0 +1,376 @@
1
+ import argparse
2
+ import json
3
+ from pathlib import Path
4
+ import os
5
+ import gradio as gr
6
+ import torch
7
+ from PIL import Image
8
+ import numpy as np
9
+ from transformers import CLIPTextModel, CLIPTokenizer
10
+ from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, StableDiffusionPipeline
11
+ from tqdm import tqdm
12
+ import yaml
13
+
14
+ def load_model_configs(config_path: str = "configs/model_ckpts.yaml") -> dict:
15
+ """
16
+ Load model configurations from a YAML file.
17
+ Returns a dictionary with model IDs and their details.
18
+ """
19
+ try:
20
+ with open(config_path, 'r') as f:
21
+ configs = yaml.safe_load(f)
22
+ return {cfg['model_id']: cfg for cfg in configs}
23
+ except (IOError, yaml.YAMLError) as e:
24
+ raise ValueError(f"Error loading {config_path}: {e}")
25
+
26
+ def get_examples(examples_dir: str = "apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning") -> list:
27
+ """
28
+ Load example data from the assets/examples directory.
29
+ Each example is a subdirectory containing a config.json and an image file.
30
+ Returns a list of [prompt, height, width, num_inference_steps, guidance_scale, seed, image_path, use_lora, finetune_model_path, lora_model_path, base_model_path, lora_rank, lora_scale].
31
+ """
32
+ if not os.path.exists(examples_dir) or not os.path.isdir(examples_dir):
33
+ raise ValueError(f"Directory {examples_dir} does not exist or is not a directory")
34
+
35
+ all_examples_dir = [os.path.join(examples_dir, d) for d in os.listdir(examples_dir)
36
+ if os.path.isdir(os.path.join(examples_dir, d))]
37
+
38
+ ans = []
39
+ for example_dir in all_examples_dir:
40
+ config_path = os.path.join(example_dir, "config.json")
41
+ image_path = os.path.join(example_dir, "result.png")
42
+
43
+ if not os.path.isfile(config_path):
44
+ print(f"Warning: config.json not found in {example_dir}")
45
+ continue
46
+ if not os.path.isfile(image_path):
47
+ print(f"Warning: result.png not found in {example_dir}")
48
+ continue
49
+
50
+ try:
51
+ with open(config_path, 'r') as f:
52
+ example_dict = json.load(f)
53
+ except (json.JSONDecodeError, IOError) as e:
54
+ print(f"Error reading or parsing {config_path}: {e}")
55
+ continue
56
+
57
+ required_keys = ["prompt", "height", "width", "num_inference_steps", "guidance_scale", "seed", "image"]
58
+ if not all(key in example_dict for key in required_keys):
59
+ print(f"Warning: Missing required keys in {config_path}")
60
+ continue
61
+
62
+ if example_dict["image"] != "result.png":
63
+ print(f"Warning: Image key in {config_path} does not match 'result.png'")
64
+ continue
65
+
66
+ try:
67
+ example_list = [
68
+ example_dict["prompt"],
69
+ example_dict["height"],
70
+ example_dict["width"],
71
+ example_dict["num_inference_steps"],
72
+ example_dict["guidance_scale"],
73
+ example_dict["seed"],
74
+ image_path,
75
+ example_dict.get("use_lora", False),
76
+ example_dict.get("finetune_model_path", "danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning"),
77
+ example_dict.get("lora_model_path", "danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA"),
78
+ example_dict.get("base_model_path", "stabilityai/stable-diffusion-2-1"),
79
+ example_dict.get("lora_rank", 64),
80
+ example_dict.get("lora_scale", 1.2)
81
+ ]
82
+ ans.append(example_list)
83
+ except KeyError as e:
84
+ print(f"Error processing {config_path}: Missing key {e}")
85
+ continue
86
+
87
+ if not ans:
88
+ model_configs = load_model_configs("configs/model_ckpts.yaml")
89
+ finetune_model_id = "danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning"
90
+ lora_model_id = "danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA"
91
+ base_model_id = model_configs[lora_model_id]['base_model_id'] if lora_model_id in model_configs else "stabilityai/stable-diffusion-2-1"
92
+ ans = [
93
+ ["a serene landscape in Ghibli style", 512, 512, 50, 3.5, 42, None, False,
94
+ model_configs.get(finetune_model_id, {}).get('local_dir', finetune_model_id),
95
+ model_configs.get(lora_model_id, {}).get('local_dir', lora_model_id),
96
+ base_model_id, 64, 1.2]
97
+ ]
98
+ return ans
99
+
100
+ def create_demo(
101
+ config_path: str = "configs/model_ckpts.yaml",
102
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
103
+ ):
104
+ # Load model configurations
105
+ model_configs = load_model_configs(config_path)
106
+ finetune_model_id = "danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning"
107
+ lora_model_id = "danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA"
108
+ finetune_model_path = model_configs[finetune_model_id]['local_dir'] if model_configs[finetune_model_id]['platform'] == "Local" else finetune_model_id
109
+ lora_model_path = model_configs[lora_model_id]['local_dir'] if model_configs[lora_model_id]['platform'] == "Local" else lora_model_id
110
+ base_model_path = model_configs[lora_model_id]['base_model_id']
111
+
112
+ # Convert device string to torch.device
113
+ device = torch.device(device)
114
+ dtype = torch.float16 if device.type == "cuda" else torch.float32
115
+
116
+ # Extract model IDs for dropdown choices based on type
117
+ finetune_model_ids = [mid for mid, cfg in model_configs.items() if cfg.get('type') == 'full-finetuning']
118
+ lora_model_ids = [mid for mid, cfg in model_configs.items() if cfg.get('type') == 'lora']
119
+ base_model_ids = [model_configs[mid]['base_model_id'] for mid in model_configs if 'base_model_id' in model_configs[mid]]
120
+
121
+ def update_model_path_visibility(use_lora):
122
+ """
123
+ Update visibility of model path dropdowns based on use_lora checkbox.
124
+ """
125
+ if use_lora:
126
+ return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)
127
+ return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
128
+
129
+ def generate_image(prompt, height, width, num_inference_steps, guidance_scale, seed, random_seed, use_lora, finetune_model_path, lora_model_path, base_model_path, lora_rank, lora_scale):
130
+ if not prompt:
131
+ return None, "Prompt cannot be empty."
132
+ if height % 8 != 0 or width % 8 != 0:
133
+ return None, "Height and width must be divisible by 8 (e.g., 256, 512, 1024)."
134
+ if num_inference_steps < 1 or num_inference_steps > 100:
135
+ return None, "Number of inference steps must be between 1 and 100."
136
+ if guidance_scale < 1.0 or guidance_scale > 20.0:
137
+ return None, "Guidance scale must be between 1.0 and 20.0."
138
+ if seed < 0 or seed > 4294967295:
139
+ return None, "Seed must be between 0 and 4294967295."
140
+ if use_lora and (not lora_model_path or not os.path.exists(lora_model_path) and not lora_model_path.startswith("danhtran2mind/")):
141
+ return None, f"LoRA model path {lora_model_path} does not exist or is invalid."
142
+ if use_lora and (not base_model_path or not os.path.exists(base_model_path) and not base_model_path.startswith("stabilityai/")):
143
+ return None, f"Base model path {base_model_path} does not exist or is invalid."
144
+ if not use_lora and (not finetune_model_path or not os.path.exists(finetune_model_path) and not finetune_model_path.startswith("danhtran2mind/")):
145
+ return None, f"Fine-tuned model path {finetune_model_path} does not exist or is invalid."
146
+ if use_lora and (lora_rank < 1 or lora_rank > 128):
147
+ return None, "LoRA rank must be between 1 and 128."
148
+ if use_lora and (lora_scale < 0.0 or lora_scale > 2.0):
149
+ return None, "LoRA scale must be between 0.0 and 2.0."
150
+
151
+ batch_size = 1
152
+ if random_seed:
153
+ seed = torch.randint(0, 4294967295, (1,)).item()
154
+ generator = torch.Generator(device=device).manual_seed(int(seed))
155
+
156
+ # Load models based on use_lora
157
+ if use_lora:
158
+ try:
159
+ pipe = StableDiffusionPipeline.from_pretrained(
160
+ base_model_path,
161
+ torch_dtype=dtype,
162
+ use_safetensors=True
163
+ )
164
+ pipe.load_lora_weights(lora_model_path, adapter_name="ghibli-lora", lora_scale=lora_scale)
165
+ pipe = pipe.to(device)
166
+ vae = pipe.vae
167
+ tokenizer = pipe.tokenizer
168
+ text_encoder = pipe.text_encoder
169
+ unet = pipe.unet
170
+ scheduler = PNDMScheduler.from_config(pipe.scheduler.config)
171
+ except Exception as e:
172
+ return None, f"Error loading LoRA model from {lora_model_path} or base model from {base_model_path}: {e}"
173
+ else:
174
+ try:
175
+ vae = AutoencoderKL.from_pretrained(finetune_model_path, subfolder="vae", torch_dtype=dtype).to(device)
176
+ tokenizer = CLIPTokenizer.from_pretrained(finetune_model_path, subfolder="tokenizer")
177
+ text_encoder = CLIPTextModel.from_pretrained(finetune_model_path, subfolder="text_encoder", torch_dtype=dtype).to(device)
178
+ unet = UNet2DConditionModel.from_pretrained(finetune_model_path, subfolder="unet", torch_dtype=dtype).to(device)
179
+ scheduler = PNDMScheduler.from_pretrained(finetune_model_path, subfolder="scheduler")
180
+ except Exception as e:
181
+ return None, f"Error loading fine-tuned model from {finetune_model_path}: {e}"
182
+
183
+ text_input = tokenizer(
184
+ [prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
185
+ )
186
+ with torch.no_grad():
187
+ text_embeddings = text_encoder(text_input.input_ids.to(device))[0].to(dtype=dtype)
188
+
189
+ max_length = text_input.input_ids.shape[-1]
190
+ uncond_input = tokenizer(
191
+ [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
192
+ )
193
+ with torch.no_grad():
194
+ uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0].to(dtype=dtype)
195
+
196
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
197
+
198
+ latents = torch.randn(
199
+ (batch_size, unet.config.in_channels, height // 8, width // 8),
200
+ generator=generator,
201
+ dtype=dtype,
202
+ device=device
203
+ )
204
+
205
+ scheduler.set_timesteps(num_inference_steps)
206
+ latents = latents * scheduler.init_noise_sigma
207
+
208
+ for t in tqdm(scheduler.timesteps, desc="Generating image"):
209
+ latent_model_input = torch.cat([latents] * 2)
210
+ latent_model_input = scheduler.scale_model_input(latent_model_input, t)
211
+
212
+ with torch.no_grad():
213
+ if device.type == "cuda":
214
+ with torch.autocast(device_type="cuda", dtype=torch.float16):
215
+ noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
216
+ else:
217
+ noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
218
+
219
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
220
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
221
+ latents = scheduler.step(noise_pred, t, latents).prev_sample
222
+
223
+ with torch.no_grad():
224
+ latents = latents / vae.config.scaling_factor
225
+ image = vae.decode(latents).sample
226
+
227
+ image = (image / 2 + 0.5).clamp(0, 1)
228
+ image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
229
+ image = (image * 255).round().astype("uint8")
230
+ pil_image = Image.fromarray(image[0])
231
+
232
+ return pil_image, f"Image generated successfully! Seed used: {seed}"
233
+
234
+ def load_example_image(prompt, height, width, num_inference_steps, guidance_scale, seed, image_path, use_lora, finetune_model_path, lora_model_path, base_model_path, lora_rank, lora_scale):
235
+ """
236
+ Load the image for the selected example and update input fields.
237
+ """
238
+ if image_path and Path(image_path).exists():
239
+ try:
240
+ image = Image.open(image_path)
241
+ return (
242
+ prompt, height, width, num_inference_steps, guidance_scale, seed, image,
243
+ use_lora, finetune_model_path, lora_model_path, base_model_path, lora_rank, lora_scale,
244
+ f"Loaded image: {image_path}"
245
+ )
246
+ except Exception as e:
247
+ return (
248
+ prompt, height, width, num_inference_steps, guidance_scale, seed, None,
249
+ use_lora, finetune_model_path, lora_model_path, base_model_path, lora_rank, lora_scale,
250
+ f"Error loading image: {e}"
251
+ )
252
+ return (
253
+ prompt, height, width, num_inference_steps, guidance_scale, seed, None,
254
+ use_lora, finetune_model_path, lora_model_path, base_model_path, lora_rank, lora_scale,
255
+ "No image available"
256
+ )
257
+
258
+ badges_text = r"""
259
+ <div style="text-align: left; font-size: 14px; display: flex; flex-direction: column; gap: 10px;">
260
+ <div style="display: flex; align-items: center; justify-content: left; gap: 8px;">
261
+ You can explore GitHub repository:
262
+ <a href="https://github.com/danhtran2mind/Ghibli-Stable-Diffusion-Synthesis">
263
+ <img src="https://img.shields.io/badge/GitHub-danhtran2mind%2FGhibli--Stable--Diffusion--Synthesis-blue?style=flat&logo=github" alt="GitHub Repo">
264
+ </a>.
265
+ </div>
266
+ <div style="display: flex; align-items: center; justify-content: left; gap: 8px;">
267
+ And you can explore HuggingFace Model Hub:
268
+ <a href="https://huggingface.co/spaces/danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning">
269
+ <img src="https://img.shields.io/badge/HuggingFace-danhtran2mind%2FGhibli--Stable--Diffusion--2.1--Base--finetuning-yellow?style=flat&logo=huggingface" alt="HuggingFace Space Demo">
270
+ </a>
271
+ and
272
+ <a href="https://huggingface.co/spaces/danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA">
273
+ <img src="https://img.shields.io/badge/HuggingFace-danhtran2mind%2FGhibli--Stable--Diffusion--2.1--LoRA-yellow?style=flat&logo=huggingface" alt="HuggingFace Space Demo">
274
+ </a>
275
+ </div>
276
+ </div>
277
+ """.strip()
278
+
279
+ with gr.Blocks() as demo:
280
+ gr.Markdown("# Ghibli-Style Image Generator")
281
+ gr.Markdown(badges_text)
282
+ gr.Markdown("Generate images in Ghibli style using a fine-tuned Stable Diffusion model or Stable Diffusion 2.1 with LoRA weights. Select an example below to load a pre-generated image or enter a prompt to generate a new one.")
283
+ gr.Markdown("""**Note:** For CPU inference, execution time is long (e.g., for resolution 512 × 512 with 50 inference steps, time is approximately 1700 seconds).""")
284
+
285
+ with gr.Row():
286
+ with gr.Column():
287
+ prompt = gr.Textbox(label="Prompt", placeholder="e.g., 'a serene landscape in Ghibli style'")
288
+ with gr.Row():
289
+ width = gr.Slider(32, 4096, 512, step=8, label="Generation Width")
290
+ height = gr.Slider(32, 4096, 512, step=8, label="Generation Height")
291
+ with gr.Accordion("Advanced Options", open=False):
292
+ num_inference_steps = gr.Slider(1, 100, 50, step=1, label="Number of Inference Steps")
293
+ guidance_scale = gr.Slider(1.0, 20.0, 3.5, step=0.5, label="Guidance Scale")
294
+ seed = gr.Number(42, label="Seed (0 to 4294967295)")
295
+ random_seed = gr.Checkbox(label="Use Random Seed", value=False)
296
+ use_lora = gr.Checkbox(label="Use LoRA Weights", value=False)
297
+ finetune_model_path = gr.Dropdown(
298
+ label="Fine-tuned Model Path",
299
+ choices=finetune_model_ids,
300
+ value=finetune_model_id,
301
+ visible=not use_lora.value
302
+ )
303
+ lora_model_path = gr.Dropdown(
304
+ label="LoRA Model Path",
305
+ choices=lora_model_ids,
306
+ value=lora_model_id,
307
+ visible=use_lora.value
308
+ )
309
+ base_model_path = gr.Dropdown(
310
+ label="Base Model Path",
311
+ choices=base_model_ids,
312
+ value=base_model_path,
313
+ visible=use_lora.value
314
+ )
315
+ lora_rank = gr.Slider(1, 128, 64, step=1, label="LoRA Rank", visible=use_lora.value)
316
+ lora_scale = gr.Slider(0.0, 2.0, 1.2, step=0.1, label="LoRA Scale", visible=use_lora.value)
317
+ generate_btn = gr.Button("Generate Image")
318
+
319
+ with gr.Column():
320
+ output_image = gr.Image(label="Generated Image")
321
+ output_text = gr.Textbox(label="Status")
322
+
323
+ examples = get_examples("assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning")
324
+ gr.Examples(
325
+ examples=examples,
326
+ inputs=[prompt, height, width, num_inference_steps, guidance_scale, seed, output_image, use_lora, finetune_model_path, lora_model_path, base_model_path, lora_rank, lora_scale],
327
+ outputs=[prompt, height, width, num_inference_steps, guidance_scale, seed, output_image, use_lora, finetune_model_path, lora_model_path, base_model_path, lora_rank, lora_scale, output_text],
328
+ fn=load_example_image,
329
+ cache_examples=False
330
+ )
331
+
332
+ use_lora.change(
333
+ fn=update_model_path_visibility,
334
+ inputs=use_lora,
335
+ outputs=[lora_model_path, base_model_path, finetune_model_path]
336
+ )
337
+
338
+ generate_btn.click(
339
+ fn=generate_image,
340
+ inputs=[prompt, height, width, num_inference_steps, guidance_scale, seed, random_seed, use_lora, finetune_model_path, lora_model_path, base_model_path, lora_rank, lora_scale],
341
+ outputs=[output_image, output_text]
342
+ )
343
+
344
+ return demo
345
+
346
+ if __name__ == "__main__":
347
+ parser = argparse.ArgumentParser(description="Ghibli-Style Image Generator using a fine-tuned Stable Diffusion model or Stable Diffusion 2.1 with LoRA weights.")
348
+ parser.add_argument(
349
+ "--config_path",
350
+ type=str,
351
+ default="configs/model_ckpts.yaml",
352
+ help="Path to the model configuration YAML file."
353
+ )
354
+ parser.add_argument(
355
+ "--device",
356
+ type=str,
357
+ default="cuda" if torch.cuda.is_available() else "cpu",
358
+ help="Device to run the model on (e.g., 'cuda', 'cpu')."
359
+ )
360
+ parser.add_argument(
361
+ "--port",
362
+ type=int,
363
+ default=7860,
364
+ help="Port to run the Gradio app on."
365
+ )
366
+ parser.add_argument(
367
+ "--share",
368
+ action="store_true",
369
+ default=False,
370
+ help="Set to True for public sharing (Hugging Face Spaces)."
371
+ )
372
+
373
+ args = parser.parse_args()
374
+
375
+ demo = create_demo(args.config_path, args.device)
376
+ demo.launch(server_port=args.port, share=args.share)
apps/old3-gradio_app.py ADDED
@@ -0,0 +1,438 @@
1
+ import argparse
2
+ import json
3
+ from pathlib import Path
4
+ import os
5
+ import gradio as gr
6
+ import torch
7
+ from PIL import Image
8
+ import numpy as np
9
+ from transformers import CLIPTextModel, CLIPTokenizer
10
+ from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, StableDiffusionPipeline
11
+ from tqdm import tqdm
12
+ import yaml
13
+
14
+ def load_model_configs(config_path: str = "configs/model_ckpts.yaml") -> dict:
15
+ """
16
+ Load model configurations from a YAML file.
17
+ Returns a dictionary with model IDs and their details.
18
+ """
19
+ try:
20
+ with open(config_path, 'r') as f:
21
+ configs = yaml.safe_load(f)
22
+ return {cfg['model_id']: cfg for cfg in configs}
23
+ except (IOError, yaml.YAMLError) as e:
24
+ raise ValueError(f"Error loading {config_path}: {e}")
25
+
26
+ def get_examples(examples_dir: str = "apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning") -> list:
27
+ """
28
+ Load example data from the assets/examples directory.
29
+ Each example is a subdirectory containing a config.json and an image file.
30
+ Returns a list of [prompt, height, width, num_inference_steps, guidance_scale, seed, image_path, use_lora, finetune_model_id, lora_model_id, base_model_id, lora_rank, lora_scale].
31
+ """
32
+ if not os.path.exists(examples_dir) or not os.path.isdir(examples_dir):
33
+ raise ValueError(f"Directory {examples_dir} does not exist or is not a directory")
34
+
35
+ all_examples_dir = [os.path.join(examples_dir, d) for d in os.listdir(examples_dir)
36
+ if os.path.isdir(os.path.join(examples_dir, d))]
37
+
38
+ ans = []
39
+ for example_dir in all_examples_dir:
40
+ config_path = os.path.join(example_dir, "config.json")
41
+ image_path = os.path.join(example_dir, "result.png")
42
+
43
+ if not os.path.isfile(config_path):
44
+ print(f"Warning: config.json not found in {example_dir}")
45
+ continue
46
+ if not os.path.isfile(image_path):
47
+ print(f"Warning: result.png not found in {example_dir}")
48
+ continue
49
+
50
+ try:
51
+ with open(config_path, 'r') as f:
52
+ example_dict = json.load(f)
53
+ except (json.JSONDecodeError, IOError) as e:
54
+ print(f"Error reading or parsing {config_path}: {e}")
55
+ continue
56
+
57
+ required_keys = ["prompt", "height", "width", "num_inference_steps", "guidance_scale", "seed", "image"]
58
+ if not all(key in example_dict for key in required_keys):
59
+ print(f"Warning: Missing required keys in {config_path}")
60
+ continue
61
+
62
+ if example_dict["image"] != "result.png":
63
+ print(f"Warning: Image key in {config_path} does not match 'result.png'")
64
+ continue
65
+
66
+ try:
67
+ model_configs = load_model_configs("configs/model_ckpts.yaml")
68
+ finetune_model_id = next((mid for mid, cfg in model_configs.items() if cfg.get('type') == 'full_finetuning'), None)
69
+ lora_model_id = next((mid for mid, cfg in model_configs.items() if cfg.get('type') == 'lora'), None)
70
+
71
+ example_list = [
72
+ example_dict["prompt"],
73
+ example_dict["height"],
74
+ example_dict["width"],
75
+ example_dict["num_inference_steps"],
76
+ example_dict["guidance_scale"],
77
+ example_dict["seed"],
78
+ image_path,
79
+ example_dict.get("use_lora", False),
80
+ finetune_model_id if finetune_model_id else "stabilityai/stable-diffusion-2-1-base",
81
+ lora_model_id if lora_model_id else "stabilityai/stable-diffusion-2-1",
82
+ model_configs.get(lora_model_id, {}).get('base_model_id', "stabilityai/stable-diffusion-2-1") if lora_model_id else "stabilityai/stable-diffusion-2-1",
83
+ example_dict.get("lora_rank", 64),
84
+ example_dict.get("lora_scale", 1.2)
85
+ ]
86
+ ans.append(example_list)
87
+ except KeyError as e:
88
+ print(f"Error processing {config_path}: Missing key {e}")
89
+ continue
90
+
91
+ if not ans:
92
+ model_configs = load_model_configs("configs/model_ckpts.yaml")
93
+ finetune_model_id = next((mid for mid, cfg in model_configs.items() if cfg.get('type') == 'full_finetuning'), "stabilityai/stable-diffusion-2-1-base")
94
+ lora_model_id = next((mid for mid, cfg in model_configs.items() if cfg.get('type') == 'lora'), "stabilityai/stable-diffusion-2-1")
95
+ base_model_id = model_configs.get(lora_model_id, {}).get('base_model_id', "stabilityai/stable-diffusion-2-1")
96
+
97
+ ans = [
98
+ ["a serene landscape in Ghibli style", 512, 512, 50, 3.5, 42, None, False,
99
+ finetune_model_id,
100
+ lora_model_id,
101
+ base_model_id, 64, 1.2]
102
+ ]
103
+ return ans
104
+
105
+ def create_demo(
106
+ config_path: str = "configs/model_ckpts.yaml",
107
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
108
+ ):
109
+ # Load model configurations
110
+ model_configs = load_model_configs(config_path)
111
+
112
+ # Load model IDs from YAML
113
+ finetune_model_id = next((mid for mid, cfg in model_configs.items() if cfg.get('type') == 'full_finetuning'), None)
114
+ lora_model_id = next((mid for mid, cfg in model_configs.items() if cfg.get('type') == 'lora'), None)
115
+
116
+ if not finetune_model_id or not lora_model_id:
117
+ raise ValueError("Could not find full_finetuning or lora model IDs in the configuration file.")
118
+
119
+ # Determine finetune model path
120
+ finetune_config = model_configs.get(finetune_model_id, {})
121
+ finetune_local_dir = finetune_config.get('local_dir')
122
+ if finetune_local_dir and os.path.exists(finetune_local_dir) and any(os.path.isfile(os.path.join(finetune_local_dir, f)) for f in os.listdir(finetune_local_dir)):
123
+ finetune_model_path = finetune_local_dir
124
+ else:
125
+ print(f"Local model directory for fine-tuned model '{finetune_model_id}' does not exist or is empty at '{finetune_local_dir}'. Falling back to model ID.")
126
+ finetune_model_path = finetune_model_id
127
+
128
+ # Determine LoRA model path
129
+ lora_config = model_configs.get(lora_model_id, {})
130
+ lora_local_dir = lora_config.get('local_dir')
131
+ if lora_local_dir and os.path.exists(lora_local_dir) and any(os.path.isfile(os.path.join(lora_local_dir, f)) for f in os.listdir(lora_local_dir)):
132
+ lora_model_path = lora_local_dir
133
+ else:
134
+ print(f"Local model directory for LoRA model '{lora_model_id}' does not exist or is empty at '{lora_local_dir}'. Falling back to model ID.")
135
+ lora_model_path = lora_model_id
136
+
137
+ # Determine base model path
138
+ base_model_id = lora_config.get('base_model_id', 'stabilityai/stable-diffusion-2-1')
139
+ base_model_config = model_configs.get(base_model_id, {})
140
+ base_local_dir = base_model_config.get('local_dir')
141
+ if base_local_dir and os.path.exists(base_local_dir) and any(os.path.isfile(os.path.join(base_local_dir, f)) for f in os.listdir(base_local_dir)):
142
+ base_model_path = base_local_dir
143
+ else:
144
+ print(f"Local model directory for base model '{base_model_id}' does not exist or is empty at '{base_local_dir}'. Falling back to model ID.")
145
+ base_model_path = base_model_id
146
+
147
+ # Convert device string to torch.device
148
+ device = torch.device(device)
149
+ dtype = torch.float16 if device.type == "cuda" else torch.float32
150
+
151
+ # Extract model IDs for dropdown choices based on type
152
+ finetune_model_ids = [mid for mid, cfg in model_configs.items() if cfg.get('type') == 'full_finetuning']
153
+ lora_model_ids = [mid for mid, cfg in model_configs.items() if cfg.get('type') == 'lora']
154
+ base_model_ids = [model_configs[mid]['base_model_id'] for mid in model_configs if 'base_model_id' in model_configs[mid]]
155
+
156
+ def update_model_path_visibility(use_lora):
157
+ """
158
+ Update visibility of model path dropdowns and LoRA sliders based on use_lora checkbox.
159
+ """
160
+ if use_lora:
161
+ return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)
162
+ return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
163
+
164
+ def generate_image(prompt, height, width, num_inference_steps, guidance_scale, seed, random_seed, use_lora, finetune_model_id, lora_model_id, base_model_id, lora_rank, lora_scale):
165
+ # Resolve model paths for generation
166
+ model_configs = load_model_configs(config_path)
167
+ finetune_config = model_configs.get(finetune_model_id, {})
168
+ finetune_local_dir = finetune_config.get('local_dir')
169
+ finetune_model_path = finetune_local_dir if finetune_local_dir and os.path.exists(finetune_local_dir) and any(os.path.isfile(os.path.join(finetune_local_dir, f)) for f in os.listdir(finetune_local_dir)) else finetune_model_id
170
+
171
+ lora_config = model_configs.get(lora_model_id, {})
172
+ lora_local_dir = lora_config.get('local_dir')
173
+ lora_model_path = lora_local_dir if lora_local_dir and os.path.exists(lora_local_dir) and any(os.path.isfile(os.path.join(lora_local_dir, f)) for f in os.listdir(lora_local_dir)) else lora_model_id
174
+
175
+ base_model_config = model_configs.get(base_model_id, {})
176
+ base_local_dir = base_model_config.get('local_dir')
177
+ base_model_path = base_local_dir if base_local_dir and os.path.exists(base_local_dir) and any(os.path.isfile(os.path.join(base_local_dir, f)) for f in os.listdir(base_local_dir)) else base_model_id
178
+
179
+ if not prompt:
180
+ return None, "Prompt cannot be empty."
181
+ if height % 8 != 0 or width % 8 != 0:
182
+ return None, "Height and width must be divisible by 8 (e.g., 256, 512, 1024)."
183
+ if num_inference_steps < 1 or num_inference_steps > 100:
184
+ return None, "Number of inference steps must be between 1 and 100."
185
+ if guidance_scale < 1.0 or guidance_scale > 20.0:
186
+ return None, "Guidance scale must be between 1.0 and 20.0."
187
+ if seed < 0 or seed > 4294967295:
188
+ return None, "Seed must be between 0 and 4294967295."
189
+ if use_lora and (not lora_model_path or not os.path.exists(lora_model_path) and not lora_model_path.startswith("danhtran2mind/")):
190
+ return None, f"LoRA model path {lora_model_path} does not exist or is invalid."
191
+ if use_lora and (not base_model_path or not os.path.exists(base_model_path) and not base_model_path.startswith("stabilityai/")):
192
+ return None, f"Base model path {base_model_path} does not exist or is invalid."
193
+ if not use_lora and (not finetune_model_path or not os.path.exists(finetune_model_path) and not finetune_model_path.startswith("danhtran2mind/")):
194
+ return None, f"Fine-tuned model path {finetune_model_path} does not exist or is invalid."
195
+ if use_lora and (lora_rank < 1 or lora_rank > 128):
196
+ return None, "LoRA rank must be between 1 and 128."
197
+ if use_lora and (lora_scale < 0.0 or lora_scale > 2.0):
198
+ return None, "LoRA scale must be between 0.0 and 2.0."
199
+
200
+ batch_size = 1
201
+ if random_seed:
202
+ seed = torch.randint(0, 4294967295, (1,)).item()
203
+ generator = torch.Generator(device=device).manual_seed(int(seed))
204
+
205
+ # Load models based on use_lora
206
+ if use_lora:
207
+ try:
208
+ pipe = StableDiffusionPipeline.from_pretrained(
209
+ base_model_path,
210
+ torch_dtype=dtype,
211
+ use_safetensors=True
212
+ )
213
+ pipe.load_lora_weights(lora_model_path, adapter_name="ghibli-lora", lora_scale=lora_scale)
214
+ pipe = pipe.to(device)
215
+ vae = pipe.vae
216
+ tokenizer = pipe.tokenizer
217
+ text_encoder = pipe.text_encoder
218
+ unet = pipe.unet
219
+ scheduler = PNDMScheduler.from_config(pipe.scheduler.config)
220
+ except Exception as e:
221
+ return None, f"Error loading LoRA model from {lora_model_path} or base model from {base_model_path}: {e}"
222
+ else:
223
+ try:
224
+ vae = AutoencoderKL.from_pretrained(finetune_model_path, subfolder="vae", torch_dtype=dtype).to(device)
225
+ tokenizer = CLIPTokenizer.from_pretrained(finetune_model_path, subfolder="tokenizer")
226
+ text_encoder = CLIPTextModel.from_pretrained(finetune_model_path, subfolder="text_encoder", torch_dtype=dtype).to(device)
227
+ unet = UNet2DConditionModel.from_pretrained(finetune_model_path, subfolder="unet", torch_dtype=dtype).to(device)
228
+ scheduler = PNDMScheduler.from_pretrained(finetune_model_path, subfolder="scheduler")
229
+ except Exception as e:
230
+ return None, f"Error loading fine-tuned model from {finetune_model_path}: {e}"
231
+
232
+ text_input = tokenizer(
233
+ [prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
234
+ )
235
+ with torch.no_grad():
236
+ text_embeddings = text_encoder(text_input.input_ids.to(device))[0].to(dtype=dtype)
237
+
238
+ max_length = text_input.input_ids.shape[-1]
239
+ uncond_input = tokenizer(
240
+ [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
241
+ )
242
+ with torch.no_grad():
243
+ uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0].to(dtype=dtype)
244
+
245
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
246
+
247
+ latents = torch.randn(
248
+ (batch_size, unet.config.in_channels, height // 8, width // 8),
249
+ generator=generator,
250
+ dtype=dtype,
251
+ device=device
252
+ )
253
+
254
+ scheduler.set_timesteps(num_inference_steps)
255
+ latents = latents * scheduler.init_noise_sigma
256
+
257
+ for t in tqdm(scheduler.timesteps, desc="Generating image"):
258
+ latent_model_input = torch.cat([latents] * 2)
259
+ latent_model_input = scheduler.scale_model_input(latent_model_input, t)
260
+
261
+ with torch.no_grad():
262
+ if device.type == "cuda":
263
+ with torch.autocast(device_type="cuda", dtype=torch.float16):
264
+ noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
265
+ else:
266
+ noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
267
+
268
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
269
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
270
+ latents = scheduler.step(noise_pred, t, latents).prev_sample
271
+
272
+ with torch.no_grad():
273
+ latents = latents / vae.config.scaling_factor
274
+ image = vae.decode(latents).sample
275
+
276
+ image = (image / 2 + 0.5).clamp(0, 1)
277
+ image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
278
+ image = (image * 255).round().astype("uint8")
279
+ pil_image = Image.fromarray(image[0])
280
+
281
+ # Success message includes LoRA Path and LoRA Scale when use_lora is True
282
+ if use_lora:
283
+ return pil_image, f"Image generated successfully! Seed used: {seed}, LoRA Path: {lora_model_path}, LoRA Scale: {lora_scale}"
284
+ return pil_image, f"Image generated successfully! Seed used: {seed}"
285
+
286
+ def load_example_image(prompt, height, width, num_inference_steps, guidance_scale, seed, image_path, use_lora, finetune_model_id, lora_model_id, base_model_id, lora_rank, lora_scale):
287
+ """
288
+ Load the image for the selected example and update input fields.
289
+ """
290
+ if image_path and Path(image_path).exists():
291
+ try:
292
+ image = Image.open(image_path)
293
+ return (
294
+ prompt, height, width, num_inference_steps, guidance_scale, seed, image,
295
+ use_lora, finetune_model_id, lora_model_id, base_model_id, lora_rank, lora_scale,
296
+ f"Loaded image: {image_path}"
297
+ )
298
+ except Exception as e:
299
+ return (
300
+ prompt, height, width, num_inference_steps, guidance_scale, seed, None,
301
+ use_lora, finetune_model_id, lora_model_id, base_model_id, lora_rank, lora_scale,
302
+ f"Error loading image: {e}"
303
+ )
304
+ return (
305
+ prompt, height, width, num_inference_steps, guidance_scale, seed, None,
306
+ use_lora, finetune_model_id, lora_model_id, base_model_id, lora_rank, lora_scale,
307
+ "No image available"
308
+ )
309
+
310
+ badges_text = r"""
311
+ <div style="text-align: left; font-size: 14px; display: flex; flex-direction: column; gap: 10px;">
312
+ <div style="display: flex; align-items: center; justify-content: left; gap: 8px;">
313
+ You can explore the GitHub repository:
314
+ <a href="https://github.com/danhtran2mind/Ghibli-Stable-Diffusion-Synthesis">
315
+ <img src="https://img.shields.io/badge/GitHub-danhtran2mind%2FGhibli--Stable--Diffusion--Synthesis-blue?style=flat&logo=github" alt="GitHub Repo">
316
+ </a>, and the Hugging Face Space demos:
317
+ <a href="https://huggingface.co/spaces/danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning">
318
+ <img src="https://img.shields.io/badge/HuggingFace-danhtran2mind%2FGhibli--Stable--Diffusion--2.1--Base--finetuning-yellow?style=flat&logo=huggingface" alt="HuggingFace Space Demo">
319
+ </a>
320
+ and
321
+ <a href="https://huggingface.co/spaces/danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA">
322
+ <img src="https://img.shields.io/badge/HuggingFace-danhtran2mind%2FGhibli--Stable--Diffusion--2.1--LoRA-yellow?style=flat&logo=huggingface" alt="HuggingFace Space Demo">
323
+ </a>
324
+ </div>
325
+ </div>
326
+ """.strip()
327
+
328
+ with gr.Blocks() as demo:
329
+ gr.Markdown("# Ghibli-Style Image Generator")
330
+ gr.Markdown(badges_text)
331
+ gr.Markdown("Generate images in Ghibli style using a fine-tuned Stable Diffusion model or Stable Diffusion 2.1 with LoRA weights. Select an example below to load a pre-generated image or enter a prompt to generate a new one.")
332
+ gr.Markdown("""**Note:** For CPU inference, execution time is long (e.g., for resolution 512 × 512 with 50 inference steps, time is approximately 1700 seconds).""")
333
+
334
+ with gr.Row():
335
+ with gr.Column():
336
+ prompt = gr.Textbox(label="Prompt", placeholder="e.g., 'a serene landscape in Ghibli style'")
337
+ with gr.Row():
338
+ width = gr.Slider(32, 4096, 512, step=8, label="Generation Width")
339
+ height = gr.Slider(32, 4096, 512, step=8, label="Generation Height")
340
+ with gr.Accordion("Advanced Options", open=False):
341
+ num_inference_steps = gr.Slider(1, 100, 50, step=1, label="Number of Inference Steps")
342
+ guidance_scale = gr.Slider(1.0, 20.0, 3.5, step=0.5, label="Guidance Scale")
343
+ seed = gr.Number(42, label="Seed (0 to 4294967295)")
344
+ random_seed = gr.Checkbox(label="Use Random Seed", value=False)
345
+ use_lora = gr.Checkbox(label="Use LoRA Weights", value=False)
346
+ finetune_model_path = gr.Dropdown(
347
+ label="Fine-tuned Model Path",
348
+ choices=finetune_model_ids,
349
+ value=finetune_model_id,
350
+ visible=not use_lora.value
351
+ )
352
+ lora_model_path = gr.Dropdown(
353
+ label="LoRA Model Path",
354
+ choices=lora_model_ids,
355
+ value=lora_model_id,
356
+ visible=use_lora.value
357
+ )
358
+ base_model_path = gr.Dropdown(
359
+ label="Base Model Path",
360
+ choices=base_model_ids,
361
+ value=base_model_id,
362
+ visible=use_lora.value
363
+ )
364
+
365
+ with gr.Group(visible=use_lora.value):
366
+ gr.Markdown("### LoRA Configuration")
367
+ lora_rank = gr.Slider(
368
+ 1, 128, 64, step=1,
369
+ label="LoRA Rank (controls model complexity)",
370
+ visible=use_lora.value,
371
+ info="Adjusts the rank of LoRA weights, affecting model complexity and memory usage."
372
+ )
373
+ lora_scale = gr.Slider(
374
+ 0.0, 2.0, 1.2, step=0.1,
375
+ label="LoRA Scale (controls weight influence)",
376
+ visible=use_lora.value,
377
+ info="Adjusts the influence of LoRA weights on the base model."
378
+ )
379
+ generate_btn = gr.Button("Generate Image")
380
+
381
+ with gr.Column():
382
+ output_image = gr.Image(label="Generated Image")
383
+ output_text = gr.Textbox(label="Status")
384
+
385
+ examples = get_examples("apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning")
386
+ gr.Examples(
387
+ examples=examples,
388
+ inputs=[prompt, height, width, num_inference_steps, guidance_scale, seed, output_image, use_lora, finetune_model_path, lora_model_path, base_model_path, lora_rank, lora_scale],
389
+ outputs=[prompt, height, width, num_inference_steps, guidance_scale, seed, output_image, use_lora, finetune_model_path, lora_model_path, base_model_path, lora_rank, lora_scale, output_text],
390
+ fn=load_example_image,
391
+ cache_examples=False
392
+ )
393
+
394
+ use_lora.change(
395
+ fn=update_model_path_visibility,
396
+ inputs=use_lora,
397
+ outputs=[lora_model_path, base_model_path, finetune_model_path, lora_rank, lora_scale]
398
+ )
399
+
400
+ generate_btn.click(
401
+ fn=generate_image,
402
+ inputs=[prompt, height, width, num_inference_steps, guidance_scale, seed, random_seed, use_lora, finetune_model_path, lora_model_path, base_model_path, lora_rank, lora_scale],
403
+ outputs=[output_image, output_text]
404
+ )
405
+
406
+ return demo
407
+
408
+ if __name__ == "__main__":
409
+ parser = argparse.ArgumentParser(description="Ghibli-Style Image Generator using a fine-tuned Stable Diffusion model or Stable Diffusion 2.1 with LoRA weights.")
410
+ parser.add_argument(
411
+ "--config_path",
412
+ type=str,
413
+ default="configs/model_ckpts.yaml",
414
+ help="Path to the model configuration YAML file."
415
+ )
416
+ parser.add_argument(
417
+ "--device",
418
+ type=str,
419
+ default="cuda" if torch.cuda.is_available() else "cpu",
420
+ help="Device to run the model on (e.g., 'cuda', 'cpu')."
421
+ )
422
+ parser.add_argument(
423
+ "--port",
424
+ type=int,
425
+ default=7860,
426
+ help="Port to run the Gradio app on."
427
+ )
428
+ parser.add_argument(
429
+ "--share",
430
+ action="store_true",
431
+ default=False,
432
+ help="Set to True for public sharing (Hugging Face Spaces)."
433
+ )
434
+
435
+ args = parser.parse_args()
436
+
437
+ demo = create_demo(args.config_path, args.device)
438
+ demo.launch(server_port=args.port, share=args.share)
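The generate_image routine above hand-rolls the classifier-free-guidance sampling loop (prompt and unconditional embeddings, PNDM latent denoising, VAE decode). For comparison, here is a minimal sketch of roughly the same generation done through diffusers' high-level pipeline; it is not part of the committed app, the model ID is the fine-tuned checkpoint referenced throughout this commit, and the remaining values are the app's defaults.

```python
# Minimal sketch (not part of the commit): the high-level pipeline call that the manual
# denoising loop above roughly corresponds to. Assumes diffusers/transformers are
# installed and the checkpoint is reachable on the Hugging Face Hub.
import torch
from diffusers import StableDiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

pipe = StableDiffusionPipeline.from_pretrained(
    "danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning", torch_dtype=dtype
).to(device)

generator = torch.Generator(device=device).manual_seed(42)
image = pipe(
    "a serene landscape in Ghibli style",
    height=512,
    width=512,
    num_inference_steps=50,
    guidance_scale=3.5,
    generator=generator,
).images[0]
image.save("ghibli_sample.png")
```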
apps/old4-gradio_app.py ADDED
@@ -0,0 +1,548 @@
1
+ import argparse
2
+ import json
3
+ from pathlib import Path
4
+ import os
5
+ import gradio as gr
6
+ import torch
7
+ from PIL import Image
8
+ import numpy as np
9
+ from transformers import CLIPTextModel, CLIPTokenizer
10
+ from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, StableDiffusionPipeline
11
+ from tqdm import tqdm
12
+ import yaml
13
+
14
+ def load_model_configs(config_path: str = "configs/model_ckpts.yaml") -> dict:
15
+ """
16
+ Load model configurations from a YAML file.
17
+ Returns a dictionary with model IDs and their details.
18
+ """
19
+ try:
20
+ with open(config_path, 'r') as f:
21
+ configs = yaml.safe_load(f)
22
+ return {cfg['model_id']: cfg for cfg in configs}
23
+ except (IOError, yaml.YAMLError) as e:
24
+ raise ValueError(f"Error loading {config_path}: {e}")
25
+
26
+ def get_examples(examples_dir: str = "apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning") -> list:
27
+
28
+ if not os.path.exists(examples_dir) or not os.path.isdir(examples_dir):
29
+ raise ValueError(f"Directory {examples_dir} does not exist or is not a directory")
30
+
31
+ all_examples_dir = [os.path.join(examples_dir, d) for d in os.listdir(examples_dir)
32
+ if os.path.isdir(os.path.join(examples_dir, d))]
33
+
34
+ ans = []
35
+ for example_dir in all_examples_dir:
36
+ config_path = os.path.join(example_dir, "config.json")
37
+ image_path = os.path.join(example_dir, "result.png")
38
+
39
+ if not os.path.isfile(config_path):
40
+ print(f"Warning: config.json not found in {example_dir}")
41
+ continue
42
+ if not os.path.isfile(image_path):
43
+ print(f"Warning: result.png not found in {example_dir}")
44
+ continue
45
+
46
+ try:
47
+ with open(config_path, 'r') as f:
48
+ example_dict = json.load(f)
49
+ except (json.JSONDecodeError, IOError) as e:
50
+ print(f"Error reading or parsing {config_path}: {e}")
51
+ continue
52
+
53
+ # Required keys for all configs
54
+ required_keys = ["prompt", "height", "width", "num_inference_steps", "guidance_scale", "seed", "image"]
55
+ if not all(key in example_dict for key in required_keys):
56
+ print(f"Warning: Missing required keys in {config_path}")
57
+ continue
58
+
59
+ if example_dict["image"] != "result.png":
60
+ print(f"Warning: Image key in {config_path} does not match 'result.png'")
61
+ continue
62
+
63
+ try:
64
+ use_lora = example_dict.get("use_lora", False)
65
+ example_list = [
66
+ example_dict["prompt"],
67
+ example_dict["height"],
68
+ example_dict["width"],
69
+ example_dict["num_inference_steps"],
70
+ example_dict["guidance_scale"],
71
+ example_dict["seed"],
72
+ image_path,
73
+ use_lora
74
+ ]
75
+
76
+ if use_lora:
77
+ # Additional required keys for LoRA config
78
+ lora_required_keys = ["lora_model_id", "base_model_id", "lora_rank", "lora_scale"]
79
+ if not all(key in example_dict for key in lora_required_keys):
80
+ print(f"Warning: Missing required LoRA keys in {config_path}")
81
+ continue
82
+
83
+ example_list.extend([
84
+ None, # finetune_model_id (not used for LoRA)
85
+ example_dict["lora_model_id"],
86
+ example_dict["base_model_id"],
87
+ example_dict["lora_rank"],
88
+ example_dict["lora_scale"]
89
+ ])
90
+ else:
91
+ # Additional required key for non-LoRA config
92
+ if "finetune_model_id" not in example_dict:
93
+ print(f"Warning: Missing finetune_model_id in {config_path}")
94
+ continue
95
+
96
+ example_list.extend([
97
+ example_dict["finetune_model_id"],
98
+ None, # lora_model_id
99
+ None, # base_model_id
100
+ None, # lora_rank
101
+ None # lora_scale
102
+ ])
103
+
104
+ ans.append(example_list)
105
+ except KeyError as e:
106
+ print(f"Error processing {config_path}: Missing key {e}")
107
+ continue
108
+
109
+ if not ans:
110
+ # Default example for non-LoRA
111
+ ans = [
112
+ ["a serene landscape in Ghibli style", 512, 512, 50, 3.5, 42, None, False,
113
+ "stabilityai/stable-diffusion-2-1-base",
114
+ None, None, None, None]
115
+ ]
116
+ # Default example for LoRA
117
+ ans.append(
118
+ ["a serene landscape in Ghibli style", 512, 512, 50, 3.5, 42, None, True,
119
+ None,
120
+ "stabilityai/stable-diffusion-2-1",
121
+ "stabilityai/stable-diffusion-2-1",
122
+ 64, 1.2]
123
+ )
124
+
125
+ return ans
126
+
127
+ def create_demo(
128
+ config_path: str = "configs/model_ckpts.yaml",
129
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
130
+ ):
131
+ # Load model configurations
132
+ model_configs = load_model_configs(config_path)
133
+
134
+ # Load model IDs from YAML
135
+ finetune_model_id = next((mid for mid, cfg in model_configs.items() if cfg.get('type') == 'full_finetuning'), None)
136
+ lora_model_id = next((mid for mid, cfg in model_configs.items() if cfg.get('type') == 'lora'), None)
137
+
138
+ if not finetune_model_id or not lora_model_id:
139
+ raise ValueError("Could not find full_finetuning or lora model IDs in the configuration file.")
140
+
141
+ # Determine finetune model path
142
+ finetune_config = model_configs.get(finetune_model_id, {})
143
+ finetune_local_dir = finetune_config.get('local_dir')
144
+ if finetune_local_dir and os.path.exists(finetune_local_dir) and any(os.path.isfile(os.path.join(finetune_local_dir, f)) for f in os.listdir(finetune_local_dir)):
145
+ finetune_model_path = finetune_local_dir
146
+ else:
147
+ print(f"Local model directory for fine-tuned model '{finetune_model_id}' does not exist or is empty at '{finetune_local_dir}'. Falling back to model ID.")
148
+ finetune_model_path = finetune_model_id
149
+
150
+ # Determine LoRA model path
151
+ lora_config = model_configs.get(lora_model_id, {})
152
+ lora_local_dir = lora_config.get('local_dir')
153
+ if lora_local_dir and os.path.exists(lora_local_dir) and any(os.path.isfile(os.path.join(lora_local_dir, f)) for f in os.listdir(lora_local_dir)):
154
+ lora_model_path = lora_local_dir
155
+ else:
156
+ print(f"Local model directory for LoRA model '{lora_model_id}' does not exist or is empty at '{lora_local_dir}'. Falling back to model ID.")
157
+ lora_model_path = lora_model_id
158
+
159
+ # Determine base model path
160
+ base_model_id = lora_config.get('base_model_id', 'stabilityai/stable-diffusion-2-1')
161
+ base_model_config = model_configs.get(base_model_id, {})
162
+ base_local_dir = base_model_config.get('local_dir')
163
+ if base_local_dir and os.path.exists(base_local_dir) and any(os.path.isfile(os.path.join(base_local_dir, f)) for f in os.listdir(base_local_dir)):
164
+ base_model_path = base_local_dir
165
+ else:
166
+ print(f"Local model directory for base model '{base_model_id}' does not exist or is empty at '{base_local_dir}'. Falling back to model ID.")
167
+ base_model_path = base_model_id
168
+
169
+ # Convert device string to torch.device
170
+ device = torch.device(device)
171
+ dtype = torch.float16 if device.type == "cuda" else torch.float32
172
+
173
+ # Extract model IDs for dropdown choices based on type
174
+ finetune_model_ids = [mid for mid, cfg in model_configs.items() if cfg.get('type') == 'full_finetuning']
175
+ lora_model_ids = [mid for mid, cfg in model_configs.items() if cfg.get('type') == 'lora']
176
+ base_model_ids = [model_configs[mid]['base_model_id'] for mid in model_configs if 'base_model_id' in model_configs[mid]]
177
+
178
+ def update_model_path_visibility(use_lora):
179
+ """
180
+ Update visibility of model path dropdowns and LoRA sliders based on use_lora checkbox.
181
+ """
182
+ if use_lora:
183
+ return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)
184
+ return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
185
+
186
+ def generate_image(prompt, height, width, num_inference_steps, guidance_scale, seed, random_seed, use_lora, finetune_model_id, lora_model_id, base_model_id, lora_rank, lora_scale):
187
+ # Resolve model paths for generation
188
+ model_configs = load_model_configs(config_path)
189
+ finetune_config = model_configs.get(finetune_model_id, {})
190
+ finetune_local_dir = finetune_config.get('local_dir')
191
+ finetune_model_path = finetune_local_dir if finetune_local_dir and os.path.exists(finetune_local_dir) and any(os.path.isfile(os.path.join(finetune_local_dir, f)) for f in os.listdir(finetune_local_dir)) else finetune_model_id
192
+
193
+ lora_config = model_configs.get(lora_model_id, {})
194
+ lora_local_dir = lora_config.get('local_dir')
195
+ lora_model_path = lora_local_dir if lora_local_dir and os.path.exists(lora_local_dir) and any(os.path.isfile(os.path.join(lora_local_dir, f)) for f in os.listdir(lora_local_dir)) else lora_model_id
196
+
197
+ base_model_config = model_configs.get(base_model_id, {})
198
+ base_local_dir = base_model_config.get('local_dir')
199
+ base_model_path = base_local_dir if base_local_dir and os.path.exists(base_local_dir) and any(os.path.isfile(os.path.join(base_local_dir, f)) for f in os.listdir(base_local_dir)) else base_model_id
200
+
201
+ if not prompt:
202
+ return None, "Prompt cannot be empty."
203
+ if height % 8 != 0 or width % 8 != 0:
204
+ return None, "Height and width must be divisible by 8 (e.g., 256, 512, 1024)."
205
+ if num_inference_steps < 1 or num_inference_steps > 100:
206
+ return None, "Number of inference steps must be between 1 and 100."
207
+ if guidance_scale < 1.0 or guidance_scale > 20.0:
208
+ return None, "Guidance scale must be between 1.0 and 20.0."
209
+ if seed < 0 or seed > 4294967295:
210
+ return None, "Seed must be between 0 and 4294967295."
211
+ if use_lora and (not lora_model_path or not os.path.exists(lora_model_path) and not lora_model_path.startswith("danhtran2mind/")):
212
+ return None, f"LoRA model path {lora_model_path} does not exist or is invalid."
213
+ if use_lora and (not base_model_path or not os.path.exists(base_model_path) and not base_model_path.startswith("stabilityai/")):
214
+ return None, f"Base model path {base_model_path} does not exist or is invalid."
215
+ if not use_lora and (not finetune_model_path or not os.path.exists(finetune_model_path) and not finetune_model_path.startswith("danhtran2mind/")):
216
+ return None, f"Fine-tuned model path {finetune_model_path} does not exist or is invalid."
217
+ if use_lora and (lora_rank < 1 or lora_rank > 128):
218
+ return None, "LoRA rank must be between 1 and 128."
219
+ if use_lora and (lora_scale < 0.0 or lora_scale > 2.0):
220
+ return None, "LoRA scale must be between 0.0 and 2.0."
221
+
222
+ batch_size = 1
223
+ if random_seed:
224
+ seed = torch.randint(0, 4294967295, (1,)).item()
225
+ generator = torch.Generator(device=device).manual_seed(int(seed))
226
+
227
+ # Load models based on use_lora
228
+ if use_lora:
229
+ try:
230
+ pipe = StableDiffusionPipeline.from_pretrained(
231
+ base_model_path,
232
+ torch_dtype=dtype,
233
+ use_safetensors=True
234
+ )
235
+ pipe.load_lora_weights(lora_model_path, adapter_name="ghibli-lora", lora_scale=lora_scale)
236
+ pipe = pipe.to(device)
237
+ vae = pipe.vae
238
+ tokenizer = pipe.tokenizer
239
+ text_encoder = pipe.text_encoder
240
+ unet = pipe.unet
241
+ scheduler = PNDMScheduler.from_config(pipe.scheduler.config)
242
+ except Exception as e:
243
+ return None, f"Error loading LoRA model from {lora_model_path} or base model from {base_model_path}: {e}"
244
+ else:
245
+ try:
246
+ vae = AutoencoderKL.from_pretrained(finetune_model_path, subfolder="vae", torch_dtype=dtype).to(device)
247
+ tokenizer = CLIPTokenizer.from_pretrained(finetune_model_path, subfolder="tokenizer")
248
+ text_encoder = CLIPTextModel.from_pretrained(finetune_model_path, subfolder="text_encoder", torch_dtype=dtype).to(device)
249
+ unet = UNet2DConditionModel.from_pretrained(finetune_model_path, subfolder="unet", torch_dtype=dtype).to(device)
250
+ scheduler = PNDMScheduler.from_pretrained(finetune_model_path, subfolder="scheduler")
251
+ except Exception as e:
252
+ return None, f"Error loading fine-tuned model from {finetune_model_path}: {e}"
253
+
254
+ text_input = tokenizer(
255
+ [prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
256
+ )
257
+ with torch.no_grad():
258
+ text_embeddings = text_encoder(text_input.input_ids.to(device))[0].to(dtype=dtype)
259
+
260
+ max_length = text_input.input_ids.shape[-1]
261
+ uncond_input = tokenizer(
262
+ [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
263
+ )
264
+ with torch.no_grad():
265
+ uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0].to(dtype=dtype)
266
+
267
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
268
+
269
+ latents = torch.randn(
270
+ (batch_size, unet.config.in_channels, height // 8, width // 8),
271
+ generator=generator,
272
+ dtype=dtype,
273
+ device=device
274
+ )
275
+
276
+ scheduler.set_timesteps(num_inference_steps)
277
+ latents = latents * scheduler.init_noise_sigma
278
+
279
+ for t in tqdm(scheduler.timesteps, desc="Generating image"):
280
+ latent_model_input = torch.cat([latents] * 2)
281
+ latent_model_input = scheduler.scale_model_input(latent_model_input, t)
282
+
283
+ with torch.no_grad():
284
+ if device.type == "cuda":
285
+ with torch.autocast(device_type="cuda", dtype=torch.float16):
286
+ noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
287
+ else:
288
+ noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
289
+
290
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
291
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
292
+ latents = scheduler.step(noise_pred, t, latents).prev_sample
293
+
294
+ with torch.no_grad():
295
+ latents = latents / vae.config.scaling_factor
296
+ image = vae.decode(latents).sample
297
+
298
+ image = (image / 2 + 0.5).clamp(0, 1)
299
+ image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
300
+ image = (image * 255).round().astype("uint8")
301
+ pil_image = Image.fromarray(image[0])
302
+
303
+ # Success message includes LoRA Path and LoRA Scale when use_lora is True
304
+ if use_lora:
305
+ return pil_image, f"Image generated successfully! Seed used: {seed}, LoRA Path: {lora_model_path}, LoRA Scale: {lora_scale}"
306
+ return pil_image, f"Image generated successfully! Seed used: {seed}"
307
+
308
+ def load_example_image(prompt, height, width, num_inference_steps, guidance_scale,
309
+ seed, image_path, use_lora, finetune_model_id, lora_model_id,
310
+ base_model_id, lora_rank, lora_scale):
311
+ """
312
+ Load the image for the selected example and update input fields.
313
+ """
314
+ if image_path and Path(image_path).exists():
315
+ try:
316
+ image = Image.open(image_path)
317
+ return (
318
+ prompt, height, width, num_inference_steps, guidance_scale, seed, image,
319
+ use_lora, finetune_model_id, lora_model_id, base_model_id, lora_rank, lora_scale,
320
+ f"Loaded image: {image_path}"
321
+ )
322
+ except Exception as e:
323
+ return (
324
+ prompt, height, width, num_inference_steps, guidance_scale, seed, None,
325
+ use_lora, finetune_model_id, lora_model_id, base_model_id, lora_rank, lora_scale,
326
+ f"Error loading image: {e}"
327
+ )
328
+ return (
329
+ prompt, height, width, num_inference_steps, guidance_scale, seed, None,
330
+ use_lora, finetune_model_id, lora_model_id, base_model_id, lora_rank, lora_scale,
331
+ "No image available"
332
+ )
333
+
334
+ badges_text = r"""
335
+ <div style="text-align: left; font-size: 14px; display: flex; flex-direction: column; gap: 10px;">
336
+ <div style="display: flex; align-items: center; justify-content: left; gap: 8px;">
337
+ You can explore the GitHub repository:
338
+ <a href="https://github.com/danhtran2mind/Ghibli-Stable-Diffusion-Synthesis">
339
+ <img src="https://img.shields.io/badge/GitHub-danhtran2mind%2FGhibli--Stable--Diffusion--Synthesis-blue?style=flat&logo=github" alt="GitHub Repo">
340
+ </a>, and the Hugging Face Space demos:
341
+ <a href="https://huggingface.co/spaces/danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning">
342
+ <img src="https://img.shields.io/badge/HuggingFace-danhtran2mind%2FGhibli--Stable--Diffusion--2.1--Base--finetuning-yellow?style=flat&logo=huggingface" alt="HuggingFace Space Demo">
343
+ </a>
344
+ and
345
+ <a href="https://huggingface.co/spaces/danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA">
346
+ <img src="https://img.shields.io/badge/HuggingFace-danhtran2mind%2FGhibli--Stable--Diffusion--2.1--LoRA-yellow?style=flat&logo=huggingface" alt="HuggingFace Space Demo">
347
+ </a>
348
+ </div>
349
+ </div>
350
+ """.strip()
351
+
352
+ with gr.Blocks() as demo:
353
+ # Main Layout: Split into Input and Output Columns
354
+ with gr.Row():
355
+ # Input Column
356
+ with gr.Column(scale=1):
357
+ gr.Markdown("## Image Generation Settings")
358
+
359
+ # Prompt Input
360
+ prompt = gr.Textbox(
361
+ label="Prompt",
362
+ placeholder="e.g., 'a serene landscape in Ghibli style'",
363
+ lines=2
364
+ )
365
+
366
+ # Image Dimensions
367
+ with gr.Group():
368
+ gr.Markdown("### Image Dimensions")
369
+ with gr.Row():
370
+ width = gr.Slider(
371
+ minimum=32,
372
+ maximum=4096,
373
+ value=512,
374
+ step=8,
375
+ label="Width"
376
+ )
377
+ height = gr.Slider(
378
+ minimum=32,
379
+ maximum=4096,
380
+ value=512,
381
+ step=8,
382
+ label="Height"
383
+ )
384
+
385
+ # Advanced Settings Accordion
386
+ with gr.Accordion("Advanced Settings", open=False):
387
+ num_inference_steps = gr.Slider(
388
+ minimum=1,
389
+ maximum=100,
390
+ value=50,
391
+ step=1,
392
+ label="Inference Steps",
393
+ info="Higher steps improve quality but increase generation time."
394
+ )
395
+ guidance_scale = gr.Slider(
396
+ minimum=1.0,
397
+ maximum=20.0,
398
+ value=3.5,
399
+ step=0.5,
400
+ label="Guidance Scale",
401
+ info="Controls how closely the image follows the prompt."
402
+ )
403
+ lora_rank = gr.Slider(
404
+ minimum=1,
405
+ maximum=128,
406
+ value=64,
407
+ step=1,
408
+ visible=False, # Initially hidden
409
+ label="LoRA Rank",
410
+ info="Controls model complexity and memory usage."
411
+ )
412
+ lora_scale = gr.Slider(
413
+ minimum=0.0,
414
+ maximum=2.0,
415
+ value=1.2,
416
+ step=0.1,
417
+ visible=False, # Initially hidden
418
+ label="LoRA Scale",
419
+ info="Adjusts the influence of LoRA weights."
420
+ )
421
+ random_seed = gr.Checkbox(
422
+ label="Use Random Seed",
423
+ value=False
424
+ )
425
+ seed = gr.Slider(
426
+ minimum=0,
427
+ maximum=4294967295,
428
+ value=42,
429
+ step=1,
430
+ label="Seed (0–4294967295)",
431
+ info="Set a specific seed for reproducible results."
432
+ )
433
+
434
+ # Model Selection
435
+ with gr.Group():
436
+ gr.Markdown("### Model Configuration")
437
+ use_lora = gr.Checkbox(
438
+ label="Use LoRA Weights",
439
+ value=False,
440
+ info="Enable to use LoRA weights with a base model."
441
+ )
442
+
443
+ # Model Path Dropdowns
444
+ finetune_model_path = gr.Dropdown(
445
+ label="Fine-tuned Model",
446
+ choices=finetune_model_ids,
447
+ value=finetune_model_id,
448
+ visible=not use_lora.value
449
+ )
450
+ lora_model_path = gr.Dropdown(
451
+ label="LoRA Model",
452
+ choices=lora_model_ids,
453
+ value=lora_model_id,
454
+ visible=use_lora.value
455
+ )
456
+ base_model_path = gr.Dropdown(
457
+ label="Base Model",
458
+ choices=base_model_ids,
459
+ value=base_model_id,
460
+ visible=use_lora.value
461
+ )
462
+
463
+ # Generate Button
464
+ generate_btn = gr.Button("Generate Image", variant="primary")
465
+
466
+ # Output Column
467
+ with gr.Column(scale=1):
468
+ gr.Markdown("## Generated Result")
469
+ output_image = gr.Image(
470
+ label="Generated Image",
471
+ interactive=False,
472
+ height=512
473
+ )
474
+ output_text = gr.Textbox(
475
+ label="Generation Status",
476
+ interactive=False,
477
+ lines=3
478
+ )
479
+
480
+ # Examples Section
481
+ gr.Markdown("## Try an Example")
482
+ examples = get_examples("apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning")
483
+ gr.Examples(
484
+ examples=examples,
485
+ inputs=[
486
+ prompt, height, width, num_inference_steps, guidance_scale, seed,
487
+ use_lora, finetune_model_path, lora_model_path, base_model_path,
488
+ lora_rank, lora_scale
489
+ ],
490
+ outputs=[
491
+ prompt, height, width, num_inference_steps, guidance_scale, seed,
492
+ output_image, use_lora, finetune_model_path, lora_model_path,
493
+ base_model_path, lora_rank, lora_scale, output_text
494
+ ],
495
+ fn=load_example_image,
496
+ cache_examples=False
497
+ )
498
+
499
+ # Event Handlers
500
+ use_lora.change(
501
+ fn=update_model_path_visibility,
502
+ inputs=use_lora,
503
+ outputs=[lora_model_path, base_model_path, finetune_model_path, lora_rank, lora_scale]
504
+ )
505
+
506
+ generate_btn.click(
507
+ fn=generate_image,
508
+ inputs=[
509
+ prompt, height, width, num_inference_steps, guidance_scale, seed,
510
+ random_seed, use_lora, finetune_model_path, lora_model_path,
511
+ base_model_path, lora_rank, lora_scale
512
+ ],
513
+ outputs=[output_image, output_text]
514
+ )
515
+
516
+ return demo
517
+
518
+ if __name__ == "__main__":
519
+ parser = argparse.ArgumentParser(description="Ghibli-Style Image Generator using a fine-tuned Stable Diffusion model or Stable Diffusion 2.1 with LoRA weights.")
520
+ parser.add_argument(
521
+ "--config_path",
522
+ type=str,
523
+ default="configs/model_ckpts.yaml",
524
+ help="Path to the model configuration YAML file."
525
+ )
526
+ parser.add_argument(
527
+ "--device",
528
+ type=str,
529
+ default="cuda" if torch.cuda.is_available() else "cpu",
530
+ help="Device to run the model on (e.g., 'cuda', 'cpu')."
531
+ )
532
+ parser.add_argument(
533
+ "--port",
534
+ type=int,
535
+ default=7860,
536
+ help="Port to run the Gradio app on."
537
+ )
538
+ parser.add_argument(
539
+ "--share",
540
+ action="store_true",
541
+ default=False,
542
+ help="Set to True for public sharing (Hugging Face Spaces)."
543
+ )
544
+
545
+ args = parser.parse_args()
546
+
547
+ demo = create_demo(args.config_path, args.device)
548
+ demo.launch(server_port=args.port, share=args.share)
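load_model_configs() above expects configs/model_ckpts.yaml to be a YAML list of per-model entries; the keys it and create_demo read are model_id, type ('full_finetuning' or 'lora'), an optional local_dir, and base_model_id for LoRA entries. A minimal sketch of that layout follows; the local_dir paths are illustrative assumptions, while the model IDs are the ones used elsewhere in this commit.

```python
# Sketch of the configs/model_ckpts.yaml structure, round-tripped through YAML and
# indexed by model_id exactly as load_model_configs() does. local_dir values are assumed.
import yaml

entries = [
    {
        "model_id": "danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning",
        "type": "full_finetuning",
        "local_dir": "ckpts/Ghibli-Stable-Diffusion-2.1-Base-finetuning",  # assumed path
    },
    {
        "model_id": "danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA",
        "type": "lora",
        "local_dir": "ckpts/Ghibli-Stable-Diffusion-2.1-LoRA",  # assumed path
        "base_model_id": "stabilityai/stable-diffusion-2-1",
    },
]

text = yaml.safe_dump(entries, sort_keys=False)
model_configs = {cfg["model_id"]: cfg for cfg in yaml.safe_load(text)}
assert model_configs["danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA"]["type"] == "lora"
print(text)
```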
apps/old5-gradio_app.py ADDED
@@ -0,0 +1,258 @@
1
+ import argparse
2
+ import json
3
+ from typing import Union, List
4
+ from pathlib import Path
5
+ import os
6
+ import gradio as gr
7
+ import torch
8
+ from PIL import Image
9
+ import numpy as np
10
+ from transformers import CLIPTextModel, CLIPTokenizer
11
+ from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, StableDiffusionPipeline
12
+ from tqdm import tqdm
13
+ import yaml
14
+
15
+ def load_model_configs(config_path: str = "configs/model_ckpts.yaml") -> dict:
16
+ with open(config_path, 'r') as f:
17
+ return {cfg['model_id']: cfg for cfg in yaml.safe_load(f)}
18
+
19
+ def get_examples(examples_dir: Union[str, List[str]] = None, use_lora: bool = None) -> List:
20
+ directories = [examples_dir] if isinstance(examples_dir, str) else examples_dir or []
21
+ valid_dirs = [d for d in directories if os.path.isdir(d)]
22
+ if not valid_dirs:
23
+ return get_provided_examples(use_lora)
24
+
25
+ examples = []
26
+ for dir_path in valid_dirs:
27
+ for subdir in sorted(os.path.join(dir_path, d) for d in os.listdir(dir_path) if os.path.isdir(os.path.join(dir_path, d))):
28
+ config_path = os.path.join(subdir, "config.json")
29
+ image_path = os.path.join(subdir, "result.png")
30
+ if not (os.path.isfile(config_path) and os.path.isfile(image_path)):
31
+ continue
32
+
33
+ with open(config_path, 'r') as f:
34
+ config = json.load(f)
35
+
36
+ required_keys = ["prompt", "height", "width", "num_inference_steps", "guidance_scale", "seed", "image"]
37
+ if config.get("use_lora", False):
38
+ required_keys.extend(["lora_model_id", "base_model_id", "lora_rank", "lora_scale"])
39
+ else:
40
+ required_keys.append("finetune_model_id")
41
+
42
+ if set(required_keys) - set(config.keys()) or config["image"] != "result.png":
43
+ continue
44
+
45
+ try:
46
+ image = Image.open(image_path)
47
+ except Exception:
48
+ continue
49
+
50
+ if use_lora is not None and config.get("use_lora", False) != use_lora:
51
+ continue
52
+
53
+ example = [config["prompt"], config["height"], config["width"], config["num_inference_steps"],
54
+ config["guidance_scale"], config["seed"], image]
55
+ example.extend([config["lora_model_id"], config["base_model_id"], config["lora_rank"], config["lora_scale"]]
56
+ if config.get("use_lora", False) else [config["finetune_model_id"]])
57
+ examples.append(example)
58
+
59
+ return examples or get_provided_examples(use_lora)
60
+
61
+ def get_provided_examples(use_lora: bool = False) -> list:
62
+ example_path = f"apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-{'LoRA' if use_lora else 'Base-finetuning'}/1/result.png"
63
+ image = Image.open(example_path) if os.path.exists(example_path) else None
64
+ return [[
65
+ "a cat is laying on a sofa in Ghibli style" if use_lora else "a serene landscape in Ghibli style",
66
+ 512, 768 if use_lora else 512, 100 if use_lora else 50, 10.0 if use_lora else 3.5, 789 if use_lora else 42,
67
+ image, "danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA" if use_lora else "danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning",
68
+ "stabilityai/stable-diffusion-2-1" if use_lora else None, 64 if use_lora else None, 0.9 if use_lora else None
69
+ ]]
70
+
71
+ def create_demo(config_path: str = "configs/model_ckpts.yaml", device: str = "cuda" if torch.cuda.is_available() else "cpu"):
72
+ model_configs = load_model_configs(config_path)
73
+ finetune_model_id = next((mid for mid, cfg in model_configs.items() if cfg.get('type') == 'full_finetuning'), None)
74
+ lora_model_id = next((mid for mid, cfg in model_configs.items() if cfg.get('type') == 'lora'), None)
75
+
76
+ if not finetune_model_id or not lora_model_id:
77
+ raise ValueError("Missing model IDs in config.")
78
+
79
+ finetune_model_path = model_configs[finetune_model_id].get('local_dir', finetune_model_id)
80
+ lora_model_path = model_configs[lora_model_id].get('local_dir', lora_model_id)
81
+ base_model_id = model_configs[lora_model_id].get('base_model_id', 'stabilityai/stable-diffusion-2-1')
82
+ base_model_path = model_configs.get(base_model_id, {}).get('local_dir', base_model_id)
83
+
84
+ device = torch.device(device)
85
+ dtype = torch.float16 if device.type == "cuda" else torch.float32
86
+
87
+ def generate_image(prompt, height, width, num_inference_steps, guidance_scale, seed, random_seed, use_lora,
88
+ finetune_model_id, lora_model_id, base_model_id, lora_rank, lora_scale):
89
+ if not prompt or height % 8 != 0 or width % 8 != 0 or num_inference_steps not in range(1, 101) or \
90
+ guidance_scale < 1.0 or guidance_scale > 20.0 or seed < 0 or seed > 4294967295 or \
91
+ (use_lora and (lora_rank < 1 or lora_rank > 128 or lora_scale < 0.0 or lora_scale > 2.0)):
92
+ return None, "Invalid input parameters."
93
+
94
+ model_configs = load_model_configs(config_path)
95
+ finetune_model_path = model_configs.get(finetune_model_id, {}).get('local_dir', finetune_model_id)
96
+ lora_model_path = model_configs.get(lora_model_id, {}).get('local_dir', lora_model_id)
97
+ base_model_path = model_configs.get(base_model_id, {}).get('local_dir', base_model_id)
98
+
99
+ generator = torch.Generator(device=device).manual_seed(torch.randint(0, 4294967295, (1,)).item() if random_seed else int(seed))
100
+
101
+ try:
102
+ if use_lora:
103
+ pipe = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=dtype, use_safetensors=True)
104
+ pipe.load_lora_weights(lora_model_path, adapter_name="ghibli-lora", lora_scale=lora_scale)
105
+ pipe = pipe.to(device)
106
+ vae, tokenizer, text_encoder, unet, scheduler = pipe.vae, pipe.tokenizer, pipe.text_encoder, pipe.unet, PNDMScheduler.from_config(pipe.scheduler.config)
107
+ else:
108
+ vae = AutoencoderKL.from_pretrained(finetune_model_path, subfolder="vae", torch_dtype=dtype).to(device)
109
+ tokenizer = CLIPTokenizer.from_pretrained(finetune_model_path, subfolder="tokenizer")
110
+ text_encoder = CLIPTextModel.from_pretrained(finetune_model_path, subfolder="text_encoder", torch_dtype=dtype).to(device)
111
+ unet = UNet2DConditionModel.from_pretrained(finetune_model_path, subfolder="unet", torch_dtype=dtype).to(device)
112
+ scheduler = PNDMScheduler.from_pretrained(finetune_model_path, subfolder="scheduler")
113
+
114
+ text_input = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
115
+ text_embeddings = text_encoder(text_input.input_ids.to(device))[0].to(dtype=dtype)
116
+
117
+ uncond_input = tokenizer([""] * 1, padding="max_length", max_length=text_input.input_ids.shape[-1], return_tensors="pt")
118
+ uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0].to(dtype=dtype)
119
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
120
+
121
+ latents = torch.randn((1, unet.config.in_channels, height // 8, width // 8), generator=generator, dtype=dtype, device=device)
122
+ scheduler.set_timesteps(num_inference_steps)
123
+ latents = latents * scheduler.init_noise_sigma
124
+
125
+ for t in tqdm(scheduler.timesteps, desc="Generating image"):
126
+ latent_model_input = torch.cat([latents] * 2)
127
+ latent_model_input = scheduler.scale_model_input(latent_model_input, t)
128
+ noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
129
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
130
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
131
+ latents = scheduler.step(noise_pred, t, latents).prev_sample
132
+
133
+ image = vae.decode(latents / vae.config.scaling_factor).sample
134
+ image = (image / 2 + 0.5).clamp(0, 1).detach().cpu().permute(0, 2, 3, 1).numpy()
135
+ pil_image = Image.fromarray((image[0] * 255).round().astype("uint8"))
136
+
137
+ if use_lora:
138
+ del pipe
139
+ else:
140
+ del vae, tokenizer, text_encoder, unet, scheduler
141
+ torch.cuda.empty_cache()
142
+
143
+ return pil_image, f"Generated image successfully! Seed used: {seed}"
144
+ except Exception as e:
145
+ return None, f"Failed to generate image: {e}"
146
+
147
+ def load_example_image_full_finetuning(prompt, height, width, num_inference_steps, guidance_scale, seed, image, finetune_model_id):
148
+ return prompt, height, width, num_inference_steps, guidance_scale, seed, image, finetune_model_id, "Loaded example successfully"
149
+
150
+ def load_example_image_lora(prompt, height, width, num_inference_steps, guidance_scale, seed, image, lora_model_id, base_model_id, lora_rank, lora_scale):
151
+ return prompt, height, width, num_inference_steps, guidance_scale, seed, image, lora_model_id, base_model_id or "stabilityai/stable-diffusion-2-1", lora_rank or 64, lora_scale or 1.2, "Loaded example successfully"
152
+
153
+ badges_text = """
154
+ <div style="text-align: left; font-size: 14px; display: flex; flex-direction: column; gap: 10px;">
155
+ <div style="display: flex; align-items: center; justify-content: left; gap: 8px;">
156
+ GitHub: <a href="https://github.com/danhtran2mind/Ghibli-Stable-Diffusion-Synthesis">
157
+ <img src="https://img.shields.io/badge/GitHub-danhtran2mind%2FGhibli--Stable--Diffusion--Synthesis-blue?style=flat&logo=github" alt="GitHub Repo">
158
+ </a> HuggingFace:
159
+ <a href="https://huggingface.co/spaces/danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning">
160
+ <img src="https://img.shields.io/badge/HuggingFace-danhtran2mind%2FGhibli--Stable--Diffusion--2.1--Base--finetuning-yellow?style=flat&logo=huggingface" alt="HuggingFace Space Demo">
161
+ </a>
162
+ <a href="https://huggingface.co/spaces/danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA">
163
+ <img src="https://img.shields.io/badge/HuggingFace-danhtran2mind%2FGhibli--Stable--Diffusion--2.1--LoRA-yellow?style=flat&logo=huggingface" alt="HuggingFace Space Demo">
164
+ </a>
165
+ </div>
166
+ </div>
167
+ """
168
+
169
+ custom_css = open("apps/gradio_app/static/styles.css", "r").read() if os.path.exists("apps/gradio_app/static/styles.css") else ""
170
+
171
+ examples_full_finetuning = get_examples("apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning", use_lora=False)
172
+ examples_lora = get_examples("apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA", use_lora=True)
173
+
174
+ with gr.Blocks(css=custom_css, theme="ocean") as demo:
175
+ gr.Markdown("## Ghibli-Style Image Generator")
176
+ with gr.Tabs():
177
+ with gr.Tab(label="Full Finetuning"):
178
+ with gr.Row():
179
+ with gr.Column(scale=1):
180
+ gr.Markdown("### Image Generation Settings")
181
+ prompt_ft = gr.Textbox(label="Prompt", placeholder="e.g., 'a serene landscape in Ghibli style'", lines=2)
182
+ with gr.Group():
183
+ gr.Markdown("#### Image Dimensions")
184
+ with gr.Row():
185
+ width_ft = gr.Slider(32, 4096, 512, step=8, label="Width")
186
+ height_ft = gr.Slider(32, 4096, 512, step=8, label="Height")
187
+ with gr.Accordion("Advanced Settings", open=False):
188
+ num_inference_steps_ft = gr.Slider(1, 100, 50, step=1, label="Inference Steps")
189
+ guidance_scale_ft = gr.Slider(1.0, 20.0, 3.5, step=0.5, label="Guidance Scale")
190
+ random_seed_ft = gr.Checkbox(label="Use Random Seed")
191
+ seed_ft = gr.Slider(0, 4294967295, 42, step=1, label="Seed")
192
+ gr.Markdown("#### Model Configuration")
193
+ finetune_model_path_ft = gr.Dropdown(label="Fine-tuned Model", choices=[mid for mid, cfg in model_configs.items() if cfg.get('type') == 'full_finetuning'], value=finetune_model_id)
194
+ with gr.Column(scale=1):
195
+ gr.Markdown("### Generated Result")
196
+ output_image_ft = gr.Image(label="Generated Image", interactive=False, height=512)
197
+ output_text_ft = gr.Textbox(label="Status", interactive=False, lines=3)
198
+ generate_btn_ft = gr.Button("Generate Image", variant="primary")
199
+ stop_btn_ft = gr.Button("Stop Generation")
200
+ gr.Markdown("### Examples for Full Finetuning")
201
+ gr.Examples(examples=examples_full_finetuning, inputs=[prompt_ft, height_ft, width_ft, num_inference_steps_ft, guidance_scale_ft, seed_ft, output_image_ft, finetune_model_path_ft],
202
+ outputs=[prompt_ft, height_ft, width_ft, num_inference_steps_ft, guidance_scale_ft, seed_ft, output_image_ft, finetune_model_path_ft, output_text_ft],
203
+ fn=load_example_image_full_finetuning, cache_examples=False, examples_per_page=4)
204
+
205
+ with gr.Tab(label="LoRA"):
206
+ with gr.Row():
207
+ with gr.Column(scale=1):
208
+ gr.Markdown("### Image Generation Settings")
209
+ prompt_lora = gr.Textbox(label="Prompt", placeholder="e.g., 'a serene landscape in Ghibli style'", lines=2)
210
+ with gr.Group():
211
+ gr.Markdown("#### Image Dimensions")
212
+ with gr.Row():
213
+ width_lora = gr.Slider(32, 4096, 512, step=8, label="Width")
214
+ height_lora = gr.Slider(32, 4096, 512, step=8, label="Height")
215
+ with gr.Accordion("Advanced Settings", open=False):
216
+ num_inference_steps_lora = gr.Slider(1, 100, 50, step=1, label="Inference Steps")
217
+ guidance_scale_lora = gr.Slider(1.0, 20.0, 3.5, step=0.5, label="Guidance Scale")
218
+ lora_rank_lora = gr.Slider(1, 128, 64, step=1, label="LoRA Rank")
219
+ lora_scale_lora = gr.Slider(0.0, 2.0, 1.2, step=0.1, label="LoRA Scale")
220
+ random_seed_lora = gr.Checkbox(label="Use Random Seed")
221
+ seed_lora = gr.Slider(0, 4294967295, 42, step=1, label="Seed")
222
+ gr.Markdown("#### Model Configuration")
223
+ lora_model_path_lora = gr.Dropdown(label="LoRA Model", choices=[mid for mid, cfg in model_configs.items() if cfg.get('type') == 'lora'], value=lora_model_id)
224
+ base_model_path_lora = gr.Dropdown(label="Base Model", choices=[model_configs[mid].get('base_model_id') for mid in model_configs if model_configs[mid].get('base_model_id')], value=base_model_id)
225
+ with gr.Column(scale=1):
226
+ gr.Markdown("### Generated Result")
227
+ output_image_lora = gr.Image(label="Generated Image", interactive=False, height=512)
228
+ output_text_lora = gr.Textbox(label="Status", interactive=False, lines=3)
229
+ generate_btn_lora = gr.Button("Generate Image", variant="primary")
230
+ stop_btn_lora = gr.Button("Stop Generation")
231
+ gr.Markdown("### Examples for LoRA")
232
+ gr.Examples(examples=examples_lora, inputs=[prompt_lora, height_lora, width_lora, num_inference_steps_lora, guidance_scale_lora, seed_lora, output_image_lora, lora_model_path_lora, base_model_path_lora, lora_rank_lora, lora_scale_lora],
233
+ outputs=[prompt_lora, height_lora, width_lora, num_inference_steps_lora, guidance_scale_lora, seed_lora, output_image_lora, lora_model_path_lora, base_model_path_lora, lora_rank_lora, lora_scale_lora, output_text_lora],
234
+ fn=load_example_image_lora, cache_examples=False, examples_per_page=4)
235
+
236
+ gr.Markdown(badges_text)
237
+
238
+ generate_event_ft = generate_btn_ft.click(fn=generate_image, inputs=[prompt_ft, height_ft, width_ft, num_inference_steps_ft, guidance_scale_ft, seed_ft, random_seed_ft, gr.State(False), finetune_model_path_ft, gr.State(None), gr.State(None), gr.State(None), gr.State(None)],
239
+ outputs=[output_image_ft, output_text_ft])
240
+ generate_event_lora = generate_btn_lora.click(fn=generate_image, inputs=[prompt_lora, height_lora, width_lora, num_inference_steps_lora, guidance_scale_lora, seed_lora, random_seed_lora, gr.State(True), gr.State(None), lora_model_path_lora, base_model_path_lora, lora_rank_lora, lora_scale_lora],
241
+ outputs=[output_image_lora, output_text_lora])
242
+
243
+ stop_btn_ft.click(fn=None, inputs=None, outputs=None, cancels=[generate_event_ft])
244
+ stop_btn_lora.click(fn=None, inputs=None, outputs=None, cancels=[generate_event_lora])
245
+
246
+ demo.unload(lambda: torch.cuda.empty_cache())
247
+
248
+ return demo
249
+
250
+ if __name__ == "__main__":
251
+ parser = argparse.ArgumentParser(description="Ghibli-Style Image Generator")
252
+ parser.add_argument("--config_path", type=str, default="configs/model_ckpts.yaml")
253
+ parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
254
+ parser.add_argument("--port", type=int, default=7860)
255
+ parser.add_argument("--share", action="store_true")
256
+ args = parser.parse_args()
257
+ demo = create_demo(args.config_path, args.device)
258
+ demo.launch(server_port=args.port, share=args.share)
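get_examples() in the apps above scans numbered sub-folders, each holding a config.json with the validated keys plus a result.png; the assets/examples entries added below follow exactly that layout. A minimal sketch that writes one such folder programmatically is shown here; the folder index and prompt are illustrative assumptions.

```python
# Sketch of one example folder in the layout get_examples() expects. The "image" field
# must literally be "result.png"; a real result.png rendered by the app would be saved
# next to the config. The folder index "5" and the prompt are illustrative assumptions.
import json
from pathlib import Path

example_dir = Path(
    "apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/5"
)
example_dir.mkdir(parents=True, exist_ok=True)

config = {
    "prompt": "a serene landscape in Ghibli style",
    "height": 512,
    "width": 512,
    "num_inference_steps": 50,
    "guidance_scale": 3.5,
    "seed": 42,
    "image": "result.png",
    "use_lora": False,
    "finetune_model_id": "danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning",
}
(example_dir / "config.json").write_text(json.dumps(config, indent=4))
```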
assets/.gitkeep ADDED
@@ -0,0 +1 @@
1
+
assets/demo_image.png ADDED

Git LFS Details

  • SHA256: d37fd25c2c25490cc4556ae7493491c7dea30bbb60753ac1a59bf8aa8e9191fe
  • Pointer size: 131 Bytes
  • Size of remote file: 467 kB
assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/1/config.json ADDED
@@ -0,0 +1,11 @@
1
+ {
2
+ "prompt": "a serene landscape in Ghibli style",
3
+ "height": 256,
4
+ "width": 512,
5
+ "num_inference_steps": 50,
6
+ "guidance_scale": 3.5,
7
+ "seed": 42,
8
+ "image": "result.png",
9
+ "use_lora": false,
10
+ "finetune_model_id": "danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning"
11
+ }
assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/1/result.png ADDED

Git LFS Details

  • SHA256: 8a955ecacd6b904093b65a7328bb1fdfc874f0866766e6f6d09bc73551a80d30
  • Pointer size: 131 Bytes
  • Size of remote file: 198 kB
assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/2/config.json ADDED
@@ -0,0 +1,11 @@
1
+ {
2
+ "prompt": "Donald Trump",
3
+ "height": 512,
4
+ "width": 512,
5
+ "num_inference_steps": 100,
6
+ "guidance_scale": 9,
7
+ "seed": 200,
8
+ "image": "result.png",
9
+ "use_lora": false,
10
+ "finetune_model_id": "danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning"
11
+ }
assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/2/result.png ADDED

Git LFS Details

  • SHA256: 3e0d8bab61ede83e5e05171b93f5aa781780ee43c955bb30f95af8554587e9bd
  • Pointer size: 131 Bytes
  • Size of remote file: 232 kB
assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/3/config.json ADDED
@@ -0,0 +1,11 @@
1
+ {
2
+ "prompt": "a dancer in Ghibli style",
3
+ "height": 384,
4
+ "width": 192,
5
+ "num_inference_steps": 50,
6
+ "guidance_scale": 15.5,
7
+ "seed": 4223,
8
+ "image": "result.png",
9
+ "use_lora": false,
10
+ "finetune_model_id": "danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning"
11
+ }
assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/3/result.png ADDED

Git LFS Details

  • SHA256: 5ef6e36606a3cfbb73a0a2a2a08b80c70e6405ddebb686d9db6108a3eed4ecb0
  • Pointer size: 131 Bytes
  • Size of remote file: 164 kB
assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/4/config.json ADDED
@@ -0,0 +1,11 @@
1
+ {
2
+ "prompt": "Ghibli style, the peace beach",
3
+ "height": 1024,
4
+ "width": 2048,
5
+ "num_inference_steps": 100,
6
+ "guidance_scale": 7.5,
7
+ "seed": 5678,
8
+ "image": "result.png",
9
+ "use_lora": false,
10
+ "finetune_model_id": "danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning"
11
+ }
assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/4/result.png ADDED

Git LFS Details

  • SHA256: 258a57cac793da71ede5b5ecf4d752a747aee3d9022ef61947cc4e82fe8d7f51
  • Pointer size: 132 Bytes
  • Size of remote file: 3.16 MB