Upload 84 files
This view is limited to 50 files because it contains too many changes. See the raw diff for the complete change set.
- .gitattributes +21 -0
- .python-version +1 -0
- CODE_OF_CONDUCT.md +1 -0
- CONTRIBUTING.md +1 -0
- LICENSE +21 -0
- SECURITY.md +1 -0
- SUPPORT.md +1 -0
- apps/gradio_app.py +33 -0
- apps/gradio_app/__init__.py +0 -0
- apps/gradio_app/aa.py +603 -0
- apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/1/config.json +11 -0
- apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/1/result.png +3 -0
- apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/2/config.json +11 -0
- apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/2/result.png +3 -0
- apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/3/config.json +11 -0
- apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/3/result.png +3 -0
- apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/4/config.json +11 -0
- apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/4/result.png +3 -0
- apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/1/config.json +13 -0
- apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/1/result.png +3 -0
- apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/2/config.json +13 -0
- apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/2/result.png +3 -0
- apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/3/config.json +13 -0
- apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/3/result.png +3 -0
- apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/4/config.json +13 -0
- apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/4/result.png +3 -0
- apps/gradio_app/assets/examples/default_image.png +3 -0
- apps/gradio_app/config_loader.py +5 -0
- apps/gradio_app/example_handler.py +60 -0
- apps/gradio_app/gui_components.py +120 -0
- apps/gradio_app/image_generator.py +54 -0
- apps/gradio_app/old-image_generator.py +77 -0
- apps/gradio_app/project_info.py +36 -0
- apps/gradio_app/setup_scripts.py +64 -0
- apps/gradio_app/static/styles.css +213 -0
- apps/old-gradio_app.py +261 -0
- apps/old2-gradio_app.py +376 -0
- apps/old3-gradio_app.py +438 -0
- apps/old4-gradio_app.py +548 -0
- apps/old5-gradio_app.py +258 -0
- assets/.gitkeep +1 -0
- assets/demo_image.png +3 -0
- assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/1/config.json +11 -0
- assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/1/result.png +3 -0
- assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/2/config.json +11 -0
- assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/2/result.png +3 -0
- assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/3/config.json +11 -0
- assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/3/result.png +3 -0
- assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/4/config.json +11 -0
- assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/4/result.png +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,24 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+apps/gradio_app/assets/examples/default_image.png filter=lfs diff=lfs merge=lfs -text
+apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/1/result.png filter=lfs diff=lfs merge=lfs -text
+apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/2/result.png filter=lfs diff=lfs merge=lfs -text
+apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/3/result.png filter=lfs diff=lfs merge=lfs -text
+apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/4/result.png filter=lfs diff=lfs merge=lfs -text
+apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/1/result.png filter=lfs diff=lfs merge=lfs -text
+apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/2/result.png filter=lfs diff=lfs merge=lfs -text
+apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/3/result.png filter=lfs diff=lfs merge=lfs -text
+apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/4/result.png filter=lfs diff=lfs merge=lfs -text
+assets/demo_image.png filter=lfs diff=lfs merge=lfs -text
+assets/examples/default_image.png filter=lfs diff=lfs merge=lfs -text
+assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/1/result.png filter=lfs diff=lfs merge=lfs -text
+assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/2/result.png filter=lfs diff=lfs merge=lfs -text
+assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/3/result.png filter=lfs diff=lfs merge=lfs -text
+assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/4/result.png filter=lfs diff=lfs merge=lfs -text
+assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/1/result.png filter=lfs diff=lfs merge=lfs -text
+assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/2/result.png filter=lfs diff=lfs merge=lfs -text
+assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/3/result.png filter=lfs diff=lfs merge=lfs -text
+assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/4/result.png filter=lfs diff=lfs merge=lfs -text
+tests/test_data/ghibli_style_output_full_finetuning.png filter=lfs diff=lfs merge=lfs -text
+tests/test_data/ghibli_style_output_lora.png filter=lfs diff=lfs merge=lfs -text
.python-version
ADDED
@@ -0,0 +1 @@
+3.10.12
CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1 @@
+
CONTRIBUTING.md
ADDED
@@ -0,0 +1 @@
+
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 Danh Tran
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
SECURITY.md
ADDED
@@ -0,0 +1 @@
+
SUPPORT.md
ADDED
@@ -0,0 +1 @@
+
apps/gradio_app.py
ADDED
@@ -0,0 +1,33 @@
+import argparse
+import subprocess
+import os
+import torch
+from gradio_app.gui_components import create_gui
+from gradio_app.config_loader import load_model_configs
+
+def run_setup_script():
+    setup_script = os.path.join(os.path.dirname(__file__),
+                                "gradio_app", "setup_scripts.py")
+    try:
+        result = subprocess.run(["python", setup_script], capture_output=True, text=True, check=True)
+        return result.stdout
+    except subprocess.CalledProcessError as e:
+        print(f"Setup script failed with error: {e.stderr}")
+        return f"Setup script failed: {e.stderr}"
+
+def main():
+    parser = argparse.ArgumentParser(description="Ghibli Stable Diffusion Synthesis")
+    parser.add_argument("--config_path", type=str, default="configs/model_ckpts.yaml")
+    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
+    parser.add_argument("--port", type=int, default=7860)
+    parser.add_argument("--share", action="store_true")
+    args = parser.parse_args()
+    print("Running setup script...")
+    run_setup_script()
+    print("Starting Gradio app...")
+    model_configs = load_model_configs(args.config_path)
+    demo = create_gui(model_configs, args.device)
+    demo.launch(server_port=args.port, share=args.share)
+
+if __name__ == "__main__":
+    main()
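Note: when launched as `python apps/gradio_app.py` from the repository root, the script's own directory (`apps/`) lands on `sys.path`, which is what lets the `gradio_app.*` imports above resolve. A minimal programmatic equivalent, offered only as a sketch under that assumption:

```python
# Minimal sketch of an equivalent programmatic launch (not the project's documented
# entry point); assumes the repo root as working directory and apps/ on sys.path.
import sys
sys.path.insert(0, "apps")  # makes the gradio_app package importable

from gradio_app.config_loader import load_model_configs
from gradio_app.gui_components import create_gui

model_configs = load_model_configs("configs/model_ckpts.yaml")  # dict keyed by model_id
demo = create_gui(model_configs, device="cpu")  # or "cuda" when available
demo.launch(server_port=7860, share=False)
```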
apps/gradio_app/__init__.py
ADDED
File without changes
apps/gradio_app/aa.py
ADDED
@@ -0,0 +1,603 @@
+import argparse
+import json
+from typing import Union, List
+from pathlib import Path
+import os
+import gradio as gr
+import torch
+from PIL import Image
+import numpy as np
+from transformers import CLIPTextModel, CLIPTokenizer
+from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, StableDiffusionPipeline
+from tqdm import tqdm
+import yaml
+
+def load_model_configs(config_path: str = "configs/model_ckpts.yaml") -> dict:
+    """
+    Load model configurations from a YAML file.
+    Returns a dictionary with model IDs and their details.
+    """
+    try:
+        with open(config_path, 'r') as f:
+            configs = yaml.safe_load(f)
+        return {cfg['model_id']: cfg for cfg in configs}
+    except (IOError, yaml.YAMLError) as e:
+        raise ValueError(f"Error loading {config_path}: {e}")
+
+def get_examples(examples_dir: Union[str, List[str]] = None,
+                 use_lora: Union[bool, None] = None) -> List:
+    # Convert single string to list
+    directories = [examples_dir] if isinstance(examples_dir, str) else examples_dir or []
+
+    # Validate directories
+    valid_dirs = [d for d in directories if os.path.isdir(d)]
+    if not valid_dirs:
+        print("Error: No valid directories found, using provided examples")
+        return get_provided_examples(use_lora)
+
+    examples = []
+    for dir_path in valid_dirs:
+        # Get sorted subdirectories
+        subdirs = sorted(
+            os.path.join(dir_path, d) for d in os.listdir(dir_path) if os.path.isdir(os.path.join(dir_path, d))
+        )
+
+        for subdir in subdirs:
+            config_path = os.path.join(subdir, "config.json")
+            image_path = os.path.join(subdir, "result.png")
+
+            if not (os.path.isfile(config_path) and os.path.isfile(image_path)):
+                print(f"Error: Missing config.json or result.png in {subdir}")
+                continue
+
+            try:
+                with open(config_path, 'r') as f:
+                    config = json.load(f)
+            except (json.JSONDecodeError, IOError) as e:
+                print(f"Error reading {config_path}: {e}")
+                continue
+
+            required_keys = ["prompt", "height", "width", "num_inference_steps", "guidance_scale", "seed", "image"]
+            if config.get("use_lora", False):
+                required_keys.extend(["lora_model_id", "base_model_id", "lora_rank", "lora_scale"])
+            else:
+                required_keys.append("finetune_model_id")
+
+            if missing_keys := set(required_keys) - set(config.keys()):
+                print(f"Error: Missing keys in {config_path}: {', '.join(missing_keys)}")
+                continue
+
+            if config["image"] != "result.png":
+                print(f"Error: Image key in {config_path} does not match 'result.png'")
+                continue
+
+            try:
+                Image.open(image_path).verify()
+                image = Image.open(image_path)  # Re-open after verify
+            except Exception as e:
+                print(f"Error: Invalid image {image_path}: {e}")
+                continue
+
+            if use_lora is not None and config.get("use_lora", False) != use_lora:
+                print(f"DEBUG: Skipping {config_path} due to use_lora mismatch (expected {use_lora}, got {config.get('use_lora', False)})")
+                continue
+
+            # Build example list based on use_lora
+            example = [
+                config["prompt"],
+                config["height"],
+                config["width"],
+                config["num_inference_steps"],
+                config["guidance_scale"],
+                config["seed"],
+                image,
+                # config.get("use_lora", False)
+            ]
+            if config.get("use_lora", False):
+                example.extend([
+                    config["lora_model_id"],
+                    config["base_model_id"],
+                    config["lora_rank"],
+                    config["lora_scale"]
+                ])
+            else:
+                example.append(config["finetune_model_id"])
+
+            examples.append(example)
+            print(f"DEBUG: Loaded example from {config_path}: {example[:6]}")
+
+    return examples or get_provided_examples(use_lora)
+
+def get_provided_examples(use_lora: bool = False) -> list:
+    example1_image = None
+    example2_image = None
+    # Attempt to load example images
+    if use_lora:
+        try:
+            example2_path = "apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/1/result.png"
+            if os.path.exists(example2_path):
+                example2_image = Image.open(example2_path)
+        except Exception as e:
+            print(f"Failed to load example2 image: {e}")
+        output = [list({
+            "prompt": "a cat is laying on a sofa in Ghibli style",
+            "width": 512,
+            "height": 768,
+            "steps": 100,
+            "cfg_scale": 10.0,
+            "seed": 789,
+            "image": example2_path,  # example2_image,
+            # "use_lora": True,
+            "model": "danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA",
+            "base_model": "stabilityai/stable-diffusion-2-1",
+            "lora_rank": 64,
+            "lora_alpha": 0.9
+        }.values())]
+
+    else:
+        try:
+            example1_path = "apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/1/result.png"
+            if os.path.exists(example1_path):
+                example1_image = Image.open(example1_path)
+        except Exception as e:
+            print(f"Failed to load example1 image: {e}")
+        output = [list({
+            "prompt": "a serene landscape in Ghibli style",
+            "width": 256,
+            "height": 512,
+            "steps": 50,
+            "cfg_scale": 3.5,
+            "seed": 42,
+            "image": example1_path,  # example1_image,
+            # "use_lora": False,
+            "model": "danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning"
+        }.values())]
+
+    return output
+
+def create_demo(
+    config_path: str = "configs/model_ckpts.yaml",
+    device: str = "cuda" if torch.cuda.is_available() else "cpu",
+):
+    model_configs = load_model_configs(config_path)
+
+    finetune_model_id = next((mid for mid, cfg in model_configs.items() if cfg.get('type') == 'full_finetuning'), None)
+    lora_model_id = next((mid for mid, cfg in model_configs.items() if cfg.get('type') == 'lora'), None)
+
+    if not finetune_model_id or not lora_model_id:
+        raise ValueError("Could not find full_finetuning or lora model IDs in the configuration file.")
+
+    finetune_config = model_configs.get(finetune_model_id, {})
+    finetune_local_dir = finetune_config.get('local_dir')
+    if finetune_local_dir and os.path.exists(finetune_local_dir) and any(os.path.isfile(os.path.join(finetune_local_dir, f)) for f in os.listdir(finetune_local_dir)):
+        finetune_model_path = finetune_local_dir
+    else:
+        finetune_model_path = finetune_model_id
+
+    lora_config = model_configs.get(lora_model_id, {})
+    lora_local_dir = lora_config.get('local_dir')
+    if lora_local_dir and os.path.exists(lora_local_dir) and any(os.path.isfile(os.path.join(lora_local_dir, f)) for f in os.listdir(lora_local_dir)):
+        lora_model_path = lora_local_dir
+    else:
+        lora_model_path = lora_model_id
+
+    base_model_id = lora_config.get('base_model_id', 'stabilityai/stable-diffusion-2-1')
+    base_model_config = model_configs.get(base_model_id, {})
+    base_local_dir = base_model_config.get('local_dir')
+    if base_local_dir and os.path.exists(base_local_dir) and any(os.path.isfile(os.path.join(base_local_dir, f)) for f in os.listdir(base_local_dir)):
+        base_model_path = base_local_dir
+    else:
+        base_model_path = base_model_id
+
+    device = torch.device(device)
+    dtype = torch.float16 if device.type == "cuda" else torch.float32
+
+    finetune_model_ids = [mid for mid, cfg in model_configs.items() if cfg.get('type') == 'full_finetuning']
+    lora_model_ids = [mid for mid, cfg in model_configs.items() if cfg.get('type') == 'lora']
+    base_model_ids = [model_configs[mid].get('base_model_id') for mid in model_configs if model_configs[mid].get('base_model_id')]
+
+    def generate_image(prompt, height, width, num_inference_steps, guidance_scale, seed, random_seed, use_lora, finetune_model_id, lora_model_id, base_model_id, lora_rank, lora_scale):
+        try:
+            model_configs = load_model_configs(config_path)
+            finetune_config = model_configs.get(finetune_model_id, {})
+            finetune_local_dir = finetune_config.get('local_dir')
+            finetune_model_path = finetune_local_dir if finetune_local_dir and os.path.exists(finetune_local_dir) and any(os.path.isfile(os.path.join(finetune_local_dir, f)) for f in os.listdir(finetune_local_dir)) else finetune_model_id
+
+            lora_config = model_configs.get(lora_model_id, {})
+            lora_local_dir = lora_config.get('local_dir')
+            lora_model_path = lora_local_dir if lora_local_dir and os.path.exists(lora_local_dir) and any(os.path.isfile(os.path.join(lora_local_dir, f)) for f in os.listdir(lora_local_dir)) else lora_model_id
+
+            base_model_config = model_configs.get(base_model_id, {})
+            base_local_dir = base_model_config.get('local_dir')
+            base_model_path = base_local_dir if base_local_dir and os.path.exists(base_local_dir) and any(os.path.isfile(os.path.join(base_local_dir, f)) for f in os.listdir(base_local_dir)) else base_model_id
+
+            if not prompt:
+                return None, "Prompt cannot be empty."
+            if height % 8 != 0 or width % 8 != 0:
+                return None, "Height and width must be divisible by 8."
+            if num_inference_steps < 1 or num_inference_steps > 100:
+                return None, "Number of inference steps must be between 1 and 100."
+            if guidance_scale < 1.0 or guidance_scale > 20.0:
+                return None, "Guidance scale must be between 1.0 and 20.0."
+            if seed < 0 or seed > 4294967295:
+                return None, "Seed must be between 0 and 4294967295."
+            if use_lora and (not lora_model_path or not os.path.exists(lora_model_path) and not lora_model_path.startswith("danhtran2mind/")):
+                return None, f"LoRA model path {lora_model_path} does not exist or is invalid."
+            if use_lora and (not base_model_path or not os.path.exists(base_model_path) and not base_model_path.startswith("stabilityai/")):
+                return None, f"Base model path {base_model_path} does not exist or is invalid."
+            if not use_lora and (not finetune_model_path or not os.path.exists(finetune_model_path) and not finetune_model_path.startswith("danhtran2mind/")):
+                return None, f"Fine-tuned model path {finetune_model_path} does not exist or is invalid."
+            if use_lora and (lora_rank < 1 or lora_rank > 128):
+                return None, "LoRA rank must be between 1 and 128."
+            if use_lora and (lora_scale < 0.0 or lora_scale > 2.0):
+                return None, "LoRA scale must be between 0.0 and 2.0."
+
+            batch_size = 1
+            if random_seed:
+                seed = torch.randint(0, 4294967295, (1,)).item()
+            generator = torch.Generator(device=device).manual_seed(int(seed))
+
+            if use_lora:
+                try:
+                    pipe = StableDiffusionPipeline.from_pretrained(
+                        base_model_path, torch_dtype=dtype, use_safetensors=True
+                    )
+                    pipe.load_lora_weights(lora_model_path, adapter_name="ghibli-lora", lora_scale=lora_scale)
+                    pipe = pipe.to(device)
+                    vae = pipe.vae
+                    tokenizer = pipe.tokenizer
+                    text_encoder = pipe.text_encoder
+                    unet = pipe.unet
+                    scheduler = PNDMScheduler.from_config(pipe.scheduler.config)
+                except Exception as e:
+                    return None, f"Error loading LoRA model: {e}"
+            else:
+                try:
+                    vae = AutoencoderKL.from_pretrained(finetune_model_path, subfolder="vae", torch_dtype=dtype).to(device)
+                    tokenizer = CLIPTokenizer.from_pretrained(finetune_model_path, subfolder="tokenizer")
+                    text_encoder = CLIPTextModel.from_pretrained(finetune_model_path, subfolder="text_encoder", torch_dtype=dtype).to(device)
+                    unet = UNet2DConditionModel.from_pretrained(finetune_model_path, subfolder="unet", torch_dtype=dtype).to(device)
+                    scheduler = PNDMScheduler.from_pretrained(finetune_model_path, subfolder="scheduler")
+                except Exception as e:
+                    return None, f"Error loading fine-tuned model: {e}"
+
+            text_input = tokenizer(
+                [prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
+            )
+            with torch.no_grad():
+                text_embeddings = text_encoder(text_input.input_ids.to(device))[0].to(dtype=dtype)
+
+            max_length = text_input.input_ids.shape[-1]
+            uncond_input = tokenizer(
+                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+            )
+            with torch.no_grad():
+                uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0].to(dtype=dtype)
+
+            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+            latents = torch.randn(
+                (batch_size, unet.config.in_channels, height // 8, width // 8),
+                generator=generator, dtype=dtype, device=device
+            )
+
+            scheduler.set_timesteps(num_inference_steps)
+            latents = latents * scheduler.init_noise_sigma
+
+            for t in tqdm(scheduler.timesteps, desc="Generating image"):
+                latent_model_input = torch.cat([latents] * 2)
+                latent_model_input = scheduler.scale_model_input(latent_model_input, t)
+
+                with torch.no_grad():
+                    if device.type == "cuda":
+                        with torch.autocast(device_type="cuda", dtype=torch.float16):
+                            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+                    else:
+                        noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+                latents = scheduler.step(noise_pred, t, latents).prev_sample
+
+            with torch.no_grad():
+                latents = latents / vae.config.scaling_factor
+                image = vae.decode(latents).sample
+
+            image = (image / 2 + 0.5).clamp(0, 1)
+            image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
+            image = (image * 255).round().astype("uint8")
+            pil_image = Image.fromarray(image[0])
+
+            if use_lora:
+                del pipe
+            else:
+                del vae, tokenizer, text_encoder, unet, scheduler
+            torch.cuda.empty_cache()
+
+            return pil_image, f"Generated image successfully! Seed used: {seed}"
+        except Exception as e:
+            return None, f"Failed to generate image: {e}"
+
+    def load_example_image_full_finetuning(prompt, height, width, num_inference_steps, guidance_scale,
+                                           seed, image, finetune_model_id):
+        try:
+            status = "Loaded example successfully"
+            return (
+                prompt, height, width, num_inference_steps, guidance_scale, seed,
+                image, finetune_model_id, status
+            )
+        except Exception as e:
+            print(f"DEBUG: Exception in load_example_image: {e}")
+            return (
+                prompt, height, width, num_inference_steps, guidance_scale, seed,
+                None, finetune_model_id,
+                f"Error loading example: {e}"
+            )
+
+    def load_example_image_lora(prompt, height, width, num_inference_steps, guidance_scale,
+                                seed, image, lora_model_id,
+                                base_model_id, lora_rank, lora_scale):
+        try:
+            status = "Loaded example successfully"
+            # Ensure base_model_id, lora_rank, and lora_scale have valid values
+            base_model_id = base_model_id or "stabilityai/stable-diffusion-2-1"
+            lora_rank = lora_rank if lora_rank is not None else 64
+            lora_scale = lora_scale if lora_scale is not None else 1.2
+
+            return (
+                prompt, height, width, num_inference_steps, guidance_scale, seed,
+                image, lora_model_id, base_model_id,
+                lora_rank, lora_scale, status
+            )
+        except Exception as e:
+            print(f"DEBUG: Exception in load_example_image_lora: {e}")
+            return (
+                prompt, height, width, num_inference_steps, guidance_scale, seed,
+                None, lora_model_id, base_model_id or "stabilityai/stable-diffusion-2-1",
+                lora_rank or 64, lora_scale or 1.2, f"Error loading example: {e}"
+            )
+
+    badges_text = r"""
+<div style="text-align: left; font-size: 14px; display: flex; flex-direction: column; gap: 10px;">
+<div style="display: flex; align-items: center; justify-content: left; gap: 8px;">
+You can explore GitHub repository:
+<a href="https://github.com/danhtran2mind/Ghibli-Stable-Diffusion-Synthesis">
+<img src="https://img.shields.io/badge/GitHub-danhtran2mind%2FGhibli--Stable--Diffusion--Synthesis-blue?style=flat&logo=github" alt="GitHub Repo">
+</a>. And you can explore HuggingFace Model Hub:
+<a href="https://huggingface.co/spaces/danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning">
+<img src="https://img.shields.io/badge/HuggingFace-danhtran2mind%2FGhibli--Stable--Diffusion--2.1--Base--finetuning-yellow?style=flat&logo=huggingface" alt="HuggingFace Space Demo">
+</a>
+and
+<a href="https://huggingface.co/spaces/danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA">
+<img src="https://img.shields.io/badge/HuggingFace-danhtran2mind%2FGhibli--Stable--Diffusion--2.1--LoRA-yellow?style=flat&logo=huggingface" alt="HuggingFace Space Demo">
+</a>
+</div>
+</div>
+""".strip()
+
+    try:
+        custom_css = open("apps/gradio_app/static/styles.css", "r").read()
+    except FileNotFoundError:
+        print("Error: styles.css not found, using default styling")
+        custom_css = ""
+
+    examples_full_finetuning = get_examples("apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning",
+                                            use_lora=False)
+    examples_lora = get_examples("apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA",
+                                 use_lora=True)
+
+    with gr.Blocks(css=custom_css, theme="ocean") as demo:
+        gr.Markdown("## Ghibli-Style Image Generator")
+        with gr.Tabs():
+            with gr.Tab(label="Full Finetuning"):
+                with gr.Row():
+                    with gr.Column(scale=1):
+                        gr.Markdown("### Image Generation Settings")
+                        prompt_ft = gr.Textbox(
+                            label="Prompt",
+                            placeholder="e.g., 'a serene landscape in Ghibli style'",
+                            lines=2
+                        )
+                        with gr.Group():
+                            gr.Markdown("#### Image Dimensions")
+                            with gr.Row():
+                                width_ft = gr.Slider(
+                                    minimum=32, maximum=4096, value=512, step=8, label="Width"
+                                )
+                                height_ft = gr.Slider(
+                                    minimum=32, maximum=4096, value=512, step=8, label="Height"
+                                )
+                        with gr.Accordion("Advanced Settings", open=False):
+                            num_inference_steps_ft = gr.Slider(
+                                minimum=1, maximum=100, value=50, step=1, label="Inference Steps",
+                                info="More steps, better quality, longer wait."
+                            )
+                            guidance_scale_ft = gr.Slider(
+                                minimum=1.0, maximum=20.0, value=3.5, step=0.5, label="Guidance Scale",
+                                info="Controls how closely the image follows the prompt."
+                            )
+                            random_seed_ft = gr.Checkbox(label="Use Random Seed", value=False)
+                            seed_ft = gr.Slider(
+                                minimum=0, maximum=4294967295, value=42, step=1,
+                                label="Seed", info="Use a seed (0-4294967295) for consistent results."
+                            )
+                        with gr.Group():
+                            gr.Markdown("#### Model Configuration")
+                            finetune_model_path_ft = gr.Dropdown(
+                                label="Fine-tuned Model", choices=finetune_model_ids,
+                                value=finetune_model_id
+                            )
+                            # image_path_ft = gr.Textbox(visible=False)
+
+                    with gr.Column(scale=1):
+                        gr.Markdown("### Generated Result")
+                        output_image_ft = gr.Image(label="Generated Image", interactive=False, height=512)
+                        output_text_ft = gr.Textbox(label="Status", interactive=False, lines=3)
+
+                        generate_btn_ft = gr.Button("Generate Image", variant="primary")
+                        stop_btn_ft = gr.Button("Stop Generation")
+
+                gr.Markdown("### Examples for Full Finetuning")
+                gr.Examples(
+                    examples=examples_full_finetuning,
+                    inputs=[
+                        prompt_ft, height_ft, width_ft, num_inference_steps_ft,
+                        guidance_scale_ft, seed_ft, output_image_ft, finetune_model_path_ft
+                    ],
+                    outputs=[prompt_ft, height_ft, width_ft, num_inference_steps_ft,
+                             guidance_scale_ft, seed_ft, output_image_ft, finetune_model_path_ft,
+                             output_text_ft],
+                    fn=load_example_image_full_finetuning,
+                    # fn=lambda *args: load_example_image_full_finetuning(*args),
+                    cache_examples=False,
+                    label="Examples for Full Fine-tuning",
+                    examples_per_page=4
+                )
+
+            with gr.Tab(label="LoRA"):
+                with gr.Row():
+                    with gr.Column(scale=1):
+                        gr.Markdown("### Image Generation Settings")
+                        prompt_lora = gr.Textbox(
+                            label="Prompt",
+                            placeholder="e.g., 'a serene landscape in Ghibli style'",
+                            lines=2
+                        )
+                        with gr.Group():
+                            gr.Markdown("#### Image Dimensions")
+                            with gr.Row():
+                                width_lora = gr.Slider(
+                                    minimum=32, maximum=4096, value=512, step=8, label="Width"
+                                )
+                                height_lora = gr.Slider(
+                                    minimum=32, maximum=4096, value=512, step=8, label="Height"
+                                )
+                        with gr.Accordion("Advanced Settings", open=False):
+                            num_inference_steps_lora = gr.Slider(
+                                minimum=1, maximum=100, value=50, step=1, label="Inference Steps",
+                                info="More steps, better quality, longer wait."
+                            )
+                            guidance_scale_lora = gr.Slider(
+                                minimum=1.0, maximum=20.0, value=3.5, step=0.5, label="Guidance Scale",
+                                info="Controls how closely the image follows the prompt."
+                            )
+                            lora_rank_lora = gr.Slider(
+                                minimum=1, maximum=128, value=64, step=1, label="LoRA Rank",
+                                info="Controls model complexity and memory usage."
+                            )
+                            lora_scale_lora = gr.Slider(
+                                minimum=0.0, maximum=2.0, value=1.2, step=0.1, label="LoRA Scale",
+                                info="Adjusts the influence of LoRA weights."
+                            )
+                            random_seed_lora = gr.Checkbox(label="Use Random Seed", value=False)
+                            seed_lora = gr.Slider(
+                                minimum=0, maximum=4294967295, value=42, step=1,
+                                label="Seed", info="Use a seed (0-4294967295) for consistent results."
+                            )
+                        with gr.Group():
+                            gr.Markdown("#### Model Configuration")
+                            lora_model_path_lora = gr.Dropdown(
+                                label="LoRA Model", choices=lora_model_ids,
+                                value=lora_model_id
+                            )
+                            base_model_path_lora = gr.Dropdown(
+                                label="Base Model", choices=base_model_ids,
+                                value=base_model_id
+                            )
+                            # image_path_lora = gr.Textbox(visible=False)
+
+                    with gr.Column(scale=1):
+                        gr.Markdown("### Generated Result")
+                        output_image_lora = gr.Image(label="Generated Image", interactive=False, height=512)
+                        output_text_lora = gr.Textbox(label="Status", interactive=False, lines=3)
+
+                        generate_btn_lora = gr.Button("Generate Image", variant="primary")
+                        stop_btn_lora = gr.Button("Stop Generation")
+
+                gr.Markdown("### Examples for LoRA")
+                gr.Examples(
+                    examples=examples_lora,
+                    inputs=[
+                        prompt_lora, height_lora, width_lora, num_inference_steps_lora,
+                        guidance_scale_lora, seed_lora, output_image_lora,
+                        lora_model_path_lora, base_model_path_lora,
+                        lora_rank_lora, lora_scale_lora
+                    ],
+                    outputs=[
+                        prompt_lora, height_lora, width_lora, num_inference_steps_lora,
+                        guidance_scale_lora, seed_lora, output_image_lora,
+                        lora_model_path_lora, base_model_path_lora,
+                        lora_rank_lora, lora_scale_lora,
+                        output_text_lora
+                    ],
+                    fn=load_example_image_lora,
+                    # fn=lambda *args: load_example_image_lora(*args),
+                    cache_examples=False,
+                    label="Examples for LoRA",
+                    examples_per_page=4
+                )
+
+        gr.Markdown(badges_text)
+
+        generate_event_ft = generate_btn_ft.click(
+            fn=generate_image,
+            inputs=[
+                prompt_ft, height_ft, width_ft, num_inference_steps_ft, guidance_scale_ft, seed_ft,
+                random_seed_ft, gr.State(value=False), finetune_model_path_ft, gr.State(value=None),
+                gr.State(value=None), gr.State(value=None), gr.State(value=None)
+            ],
+            outputs=[output_image_ft, output_text_ft]
+        )
+
+        generate_event_lora = generate_btn_lora.click(
+            fn=generate_image,
+            inputs=[
+                prompt_lora, height_lora, width_lora, num_inference_steps_lora, guidance_scale_lora, seed_lora,
+                random_seed_lora, gr.State(value=True), gr.State(value=None), lora_model_path_lora,
+                base_model_path_lora, lora_rank_lora, lora_scale_lora
+            ],
+            outputs=[output_image_lora, output_text_lora]
+        )
+
+        stop_btn_ft.click(fn=None, inputs=None, outputs=None, cancels=[generate_event_ft])
+        stop_btn_lora.click(fn=None, inputs=None, outputs=None, cancels=[generate_event_lora])
+
+        def cleanup():
+            print("DEBUG: Cleaning up resources...")
+            torch.cuda.empty_cache()
+
+        demo.unload(cleanup)
+
+    return demo
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Ghibli-Style Image Generator using a fine-tuned Stable Diffusion model or Stable Diffusion 2.1 with LoRA weights.")
+    parser.add_argument(
+        "--config_path",
+        type=str,
+        default="configs/model_ckpts.yaml",
+        help="Path to the model configuration YAML file."
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda" if torch.cuda.is_available() else "cpu",
+        help="Device to run the model on (e.g., 'cuda', 'cpu')."
+    )
+    parser.add_argument(
+        "--port",
+        type=int,
+        default=7860,
+        help="Port to run the Gradio app on."
+    )
+    parser.add_argument(
+        "--share",
+        action="store_true",
+        default=False,
+        help="Set to True for public sharing (Hugging Face Spaces)."
+    )
+
+    args = parser.parse_args()
+
+    demo = create_demo(args.config_path, args.device)
+    demo.launch(server_port=args.port, share=args.share)
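Note: `configs/model_ckpts.yaml` is not among the files shown in this view, but `aa.py` above (and `gui_components.py` below) read `model_id`, `type`, `local_dir`, and `base_model_id` from each entry. A hypothetical sketch of the shape `load_model_configs()` expects; only the published model IDs come from the code above, and the `local_dir` paths are assumptions:

```python
# Hypothetical sketch of configs/model_ckpts.yaml; the real file is not in this view.
import yaml

sample_yaml = """
- model_id: danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning
  type: full_finetuning
  local_dir: ckpts/Ghibli-Stable-Diffusion-2.1-Base-finetuning
- model_id: danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA
  type: lora
  base_model_id: stabilityai/stable-diffusion-2-1
  local_dir: ckpts/Ghibli-Stable-Diffusion-2.1-LoRA
"""

model_configs = {cfg["model_id"]: cfg for cfg in yaml.safe_load(sample_yaml)}
# create_demo()/create_gui() pick entries by type ('full_finetuning' vs 'lora') and
# fall back from local_dir to the Hub model_id when the directory is absent or empty.
```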
apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/1/config.json
ADDED
@@ -0,0 +1,11 @@
+{
+    "prompt": "a serene landscape in Ghibli style",
+    "height": 256,
+    "width": 512,
+    "num_inference_steps": 50,
+    "guidance_scale": 3.5,
+    "seed": 42,
+    "image": "result.png",
+    "use_lora": false,
+    "finetune_model_id": "danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning"
+}
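For reference, `get_examples()` in `apps/gradio_app/aa.py` above flattens a config like this into a positional row for `gr.Examples`. A sketch of the row produced from this non-LoRA config, with the opened PIL image standing in for `result.png`:

```python
# Sketch of the gr.Examples row built from the config.json above (non-LoRA case).
from PIL import Image

image = Image.open("apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/1/result.png")
example_row = [
    "a serene landscape in Ghibli style",  # prompt
    256,                                   # height
    512,                                   # width
    50,                                    # num_inference_steps
    3.5,                                   # guidance_scale
    42,                                    # seed
    image,                                 # result.png preview
    "danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning",  # finetune_model_id
]
```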
apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/1/result.png
ADDED
Git LFS Details
apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/2/config.json
ADDED
@@ -0,0 +1,11 @@
+{
+    "prompt": "Donald Trump",
+    "height": 512,
+    "width": 512,
+    "num_inference_steps": 100,
+    "guidance_scale": 9,
+    "seed": 200,
+    "image": "result.png",
+    "use_lora": false,
+    "finetune_model_id": "danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning"
+}
apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/2/result.png
ADDED
Git LFS Details
apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/3/config.json
ADDED
@@ -0,0 +1,11 @@
+{
+    "prompt": "a dancer in Ghibli style",
+    "height": 384,
+    "width": 192,
+    "num_inference_steps": 50,
+    "guidance_scale": 15.5,
+    "seed": 4223,
+    "image": "result.png",
+    "use_lora": false,
+    "finetune_model_id": "danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning"
+}
apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/3/result.png
ADDED
Git LFS Details
apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/4/config.json
ADDED
@@ -0,0 +1,11 @@
+{
+    "prompt": "Ghibli style, the peace beach",
+    "height": 1024,
+    "width": 2048,
+    "num_inference_steps": 100,
+    "guidance_scale": 7.5,
+    "seed": 5678,
+    "image": "result.png",
+    "use_lora": false,
+    "finetune_model_id": "danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning"
+}
apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/4/result.png
ADDED
Git LFS Details
apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/1/config.json
ADDED
@@ -0,0 +1,13 @@
+{
+    "prompt": "a cat is laying on a sofa in Ghibli style",
+    "height": 512,
+    "width": 768,
+    "num_inference_steps": 100,
+    "guidance_scale": 10,
+    "seed": 789,
+    "image": "result.png",
+    "use_lora": true,
+    "lora_model_id": "danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA",
+    "base_model_id": "stabilityai/stable-diffusion-2-1",
+    "lora_scale": 0.9
+}
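For reference, the LoRA example configs carry `lora_model_id`, `base_model_id`, and `lora_scale` but no `lora_rank`; `get_examples()` in `aa.py` lists `lora_rank` as required for LoRA entries, while the `example_handler.py` variant further below comments it out, so these files are only picked up by the latter. A sketch of the row it builds from this config:

```python
# Sketch of the gr.Examples row built from the LoRA config.json above.
from PIL import Image

image = Image.open("apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/1/result.png")
example_row = [
    "a cat is laying on a sofa in Ghibli style",  # prompt
    512,                                          # height
    768,                                          # width
    100,                                          # num_inference_steps
    10,                                           # guidance_scale
    789,                                          # seed
    image,                                        # result.png preview
    "danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA",  # lora_model_id
    "stabilityai/stable-diffusion-2-1",                # base_model_id
    0.9,                                               # lora_scale
]
```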
apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/1/result.png
ADDED
Git LFS Details
apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/2/config.json
ADDED
@@ -0,0 +1,13 @@
+{
+    "prompt": "Ghibli style, a majestic mountain towers, casting shadows on the serene beach.",
+    "height": 1024,
+    "width": 2048,
+    "num_inference_steps": 75,
+    "guidance_scale": 14.5,
+    "seed": 9999,
+    "image": "result.png",
+    "use_lora": true,
+    "lora_model_id": "danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA",
+    "base_model_id": "stabilityai/stable-diffusion-2-1",
+    "lora_scale": 1
+}
apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/2/result.png
ADDED
Git LFS Details
apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/3/config.json
ADDED
@@ -0,0 +1,13 @@
+{
+    "prompt": "In a soft, Ghibli style, Elon Musk is in a suit.",
+    "height": 512,
+    "width": 512,
+    "num_inference_steps": 82,
+    "guidance_scale": 18,
+    "seed": 1,
+    "image": "result.png",
+    "use_lora": true,
+    "lora_model_id": "danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA",
+    "base_model_id": "stabilityai/stable-diffusion-2-1",
+    "lora_scale": 1.4
+}
apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/3/result.png
ADDED
Git LFS Details
apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/4/config.json
ADDED
@@ -0,0 +1,13 @@
+{
+    "prompt": "In a Ghibli-esque world, A close-up shows a race car's soft, sun-drenched, whimsical details.",
+    "height": 1024,
+    "width": 1024,
+    "num_inference_steps": 42,
+    "guidance_scale": 20,
+    "seed": 1589,
+    "image": "result.png",
+    "use_lora": true,
+    "lora_model_id": "danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA",
+    "base_model_id": "stabilityai/stable-diffusion-2-1",
+    "lora_scale": 0.7
+}
apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA/4/result.png
ADDED
Git LFS Details
apps/gradio_app/assets/examples/default_image.png
ADDED
Git LFS Details
apps/gradio_app/config_loader.py
ADDED
@@ -0,0 +1,5 @@
+import yaml
+
+def load_model_configs(config_path: str = "configs/model_ckpts.yaml") -> dict:
+    with open(config_path, 'r') as f:
+        return {cfg['model_id']: cfg for cfg in yaml.safe_load(f)}
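Note: unlike the `load_model_configs` in `aa.py`, this loader propagates `IOError`/`yaml.YAMLError` directly. A small usage sketch that re-applies the same guard `aa.py` uses, assuming `apps/` is on `sys.path`:

```python
# Usage sketch; re-applies the guard aa.py wraps around the same loader.
# Assumes apps/ is on sys.path (see the launch sketch after apps/gradio_app.py above).
import yaml
from gradio_app.config_loader import load_model_configs

try:
    model_configs = load_model_configs("configs/model_ckpts.yaml")
except (IOError, yaml.YAMLError) as e:
    raise ValueError(f"Error loading configs/model_ckpts.yaml: {e}")

lora_ids = [mid for mid, cfg in model_configs.items() if cfg.get("type") == "lora"]
```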
apps/gradio_app/example_handler.py
ADDED
@@ -0,0 +1,60 @@
+import os
+import json
+from typing import Union, List
+from PIL import Image
+
+def get_examples(examples_dir: Union[str, List[str]] = None, use_lora: bool = None) -> List:
+    directories = [examples_dir] if isinstance(examples_dir, str) else examples_dir or []
+    valid_dirs = [d for d in directories if os.path.isdir(d)]
+    if not valid_dirs:
+        return get_provided_examples(use_lora)
+
+    examples = []
+    for dir_path in valid_dirs:
+        for subdir in sorted(os.path.join(dir_path, d) for d in os.listdir(dir_path) if os.path.isdir(os.path.join(dir_path, d))):
+            config_path = os.path.join(subdir, "config.json")
+            image_path = os.path.join(subdir, "result.png")
+            if not (os.path.isfile(config_path) and os.path.isfile(image_path)):
+                continue
+
+            with open(config_path, 'r') as f:
+                config = json.load(f)
+
+            required_keys = ["prompt", "height", "width", "num_inference_steps", "guidance_scale", "seed", "image"]
+            if config.get("use_lora", False):
+                required_keys.extend(["lora_model_id", "base_model_id",
+                                      # "lora_rank",
+                                      "lora_scale"])
+            else:
+                required_keys.append("finetune_model_id")
+
+            if set(required_keys) - set(config.keys()) or config["image"] != "result.png":
+                continue
+
+            try:
+                image = Image.open(image_path)
+            except Exception:
+                continue
+
+            if use_lora is not None and config.get("use_lora", False) != use_lora:
+                continue
+
+            example = [config["prompt"], config["height"], config["width"], config["num_inference_steps"],
+                       config["guidance_scale"], config["seed"], image]
+            example.extend([config["lora_model_id"], config["base_model_id"],
+                            # config["lora_rank"],
+                            config["lora_scale"]]
+                           if config.get("use_lora", False) else [config["finetune_model_id"]])
+            examples.append(example)
+
+    return examples or get_provided_examples(use_lora)
+
+def get_provided_examples(use_lora: bool = False) -> list:
+    example_path = f"apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-{'LoRA' if use_lora else 'Base-finetuning'}/1/result.png"
+    image = Image.open(example_path) if os.path.exists(example_path) else None
+    return [[
+        "a cat is laying on a sofa in Ghibli style" if use_lora else "a serene landscape in Ghibli style",
+        512, 768 if use_lora else 512, 100 if use_lora else 50, 10.0 if use_lora else 3.5, 789 if use_lora else 42,
+        image, "danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA" if use_lora else "danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning",
+        "stabilityai/stable-diffusion-2-1" if use_lora else None, 64 if use_lora else None, 0.9 if use_lora else None
+    ]]
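A short usage sketch of `get_examples()` as `gui_components.create_gui()` below calls it (falling back to `get_provided_examples()` when a directory is missing), again assuming `apps/` is on `sys.path`:

```python
# Usage sketch matching the calls in gui_components.create_gui().
# Assumes apps/ is on sys.path and the example assets exist.
from gradio_app.example_handler import get_examples

examples_ft = get_examples(
    "apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning",
    use_lora=False,
)  # each row ends with finetune_model_id

examples_lora = get_examples(
    "apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA",
    use_lora=True,
)  # each row ends with lora_model_id, base_model_id, lora_scale
```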
apps/gradio_app/gui_components.py
ADDED
@@ -0,0 +1,120 @@
+import gradio as gr
+import torch
+import os
+from .example_handler import get_examples
+from .image_generator import generate_image
+from .project_info import intro_markdown_1, intro_markdown_2, outro_markdown_1
+
+def load_example_image_full_finetuning(prompt, height, width, num_inference_steps, guidance_scale, seed, image, finetune_model_id):
+    return prompt, height, width, num_inference_steps, guidance_scale, seed, image, finetune_model_id, "Loaded example successfully"
+
+def load_example_image_lora(prompt, height, width, num_inference_steps, guidance_scale, seed, image, lora_model_id, base_model_id, lora_scale):
+    return prompt, height, width, num_inference_steps, guidance_scale, seed, image, lora_model_id, base_model_id or "stabilityai/stable-diffusion-2-1", lora_scale or 1.2, "Loaded example successfully"
+
+def create_gui(model_configs, device):
+    finetune_model_id = next((mid for mid, cfg in model_configs.items() if cfg.get('type') == 'full_finetuning'), None)
+    lora_model_id = next((mid for mid, cfg in model_configs.items() if cfg.get('type') == 'lora'), None)
+
+    if not finetune_model_id or not lora_model_id:
+        raise ValueError("Missing model IDs in config.")
+
+    base_model_id = model_configs[lora_model_id].get('base_model_id', 'stabilityai/stable-diffusion-2-1')
+    device = torch.device(device)
+    dtype = torch.float16 if device.type == "cuda" else torch.float32
+    config_path = "configs/model_ckpts.yaml"
+
+    custom_css = open("apps/gradio_app/static/styles.css", "r").read() if os.path.exists("apps/gradio_app/static/styles.css") else ""
+
+    examples_full_finetuning = get_examples("apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning", use_lora=False)
+    examples_lora = get_examples("apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA", use_lora=True)
+
+    with gr.Blocks(css=custom_css, theme="ocean") as demo:
+        gr.Markdown("# Ghibli Stable Diffusion Synthesis")
+        gr.HTML(intro_markdown_1)
+        gr.HTML(intro_markdown_2)
+        with gr.Tabs():
+            with gr.Tab(label="Full Finetuning"):
+                with gr.Row():
+                    with gr.Column(scale=1):
+                        gr.Markdown("### Image Generation Settings")
+                        prompt_ft = gr.Textbox(label="Prompt", placeholder="e.g., 'a serene landscape in Ghibli style'", lines=2)
+                        with gr.Group():
+                            gr.Markdown("#### Image Dimensions")
+                            with gr.Row():
+                                height_ft = gr.Slider(32, 4096, 512, step=8, label="Height")
+                                width_ft = gr.Slider(32, 4096, 512, step=8, label="Width")
+                        with gr.Accordion("Advanced Settings", open=False):
+                            num_inference_steps_ft = gr.Slider(1, 100, 50, step=1, label="Inference Steps")
+                            guidance_scale_ft = gr.Slider(1.0, 20.0, 3.5, step=0.5, label="Guidance Scale")
+                            random_seed_ft = gr.Checkbox(label="Use Random Seed")
+                            seed_ft = gr.Slider(0, 4294967295, 42, step=1, label="Seed")
+                        gr.Markdown("#### Model Configuration")
+                        finetune_model_path_ft = gr.Dropdown(label="Fine-tuned Model", choices=[mid for mid, cfg in model_configs.items() if cfg.get('type') == 'full_finetuning'], value=finetune_model_id)
+                    with gr.Column(scale=1):
+                        gr.Markdown("### Generated Result")
+                        output_image_ft = gr.Image(label="Generated Image", interactive=False, height=512)
+                        output_text_ft = gr.Textbox(label="Status", interactive=False, lines=3)
+                        generate_btn_ft = gr.Button("Generate Image", variant="primary")
+                        stop_btn_ft = gr.Button("Stop Generation")
+                gr.Markdown("### Examples for Full Finetuning")
+                gr.Examples(examples=examples_full_finetuning, inputs=[prompt_ft, height_ft, width_ft, num_inference_steps_ft, guidance_scale_ft, seed_ft, output_image_ft, finetune_model_path_ft],
+                            outputs=[prompt_ft, height_ft, width_ft, num_inference_steps_ft, guidance_scale_ft, seed_ft, output_image_ft, finetune_model_path_ft, output_text_ft],
+                            fn=load_example_image_full_finetuning, cache_examples=False, examples_per_page=4)
+
+            with gr.Tab(label="LoRA"):
+                with gr.Row():
+                    with gr.Column(scale=1):
+                        gr.Markdown("### Image Generation Settings")
+                        prompt_lora = gr.Textbox(label="Prompt", placeholder="e.g., 'a serene landscape in Ghibli style'", lines=2)
+                        with gr.Group():
+                            gr.Markdown("#### Image Dimensions")
+                            with gr.Row():
+                                height_lora = gr.Slider(32, 4096, 512, step=8, label="Height")
+                                width_lora = gr.Slider(32, 4096, 512, step=8, label="Width")
+                        with gr.Accordion("Advanced Settings", open=False):
+                            num_inference_steps_lora = gr.Slider(1, 100, 50, step=1, label="Inference Steps")
+                            guidance_scale_lora = gr.Slider(1.0, 20.0, 3.5, step=0.5, label="Guidance Scale")
+                            lora_scale_lora = gr.Slider(0.0, 2.0, 1.2, step=0.1, label="LoRA Scale")
+                            random_seed_lora = gr.Checkbox(label="Use Random Seed")
+                            seed_lora = gr.Slider(0, 4294967295, 42, step=1, label="Seed")
+                        gr.Markdown("#### Model Configuration")
+                        lora_model_path_lora = gr.Dropdown(label="LoRA Model", choices=[mid for mid, cfg in model_configs.items() if cfg.get('type') == 'lora'], value=lora_model_id)
+                        base_model_path_lora = gr.Dropdown(label="Base Model", choices=[model_configs[mid].get('base_model_id') for mid in model_configs if model_configs[mid].get('base_model_id')], value=base_model_id)
+                    with gr.Column(scale=1):
+                        gr.Markdown("### Generated Result")
+                        output_image_lora = gr.Image(label="Generated Image", interactive=False, height=512)
+                        output_text_lora = gr.Textbox(label="Status", interactive=False, lines=3)
+                        generate_btn_lora = gr.Button("Generate Image", variant="primary")
+                        stop_btn_lora = gr.Button("Stop Generation")
+                gr.Markdown("### Examples for LoRA")
+                gr.Examples(examples=examples_lora, inputs=[prompt_lora, height_lora, width_lora, num_inference_steps_lora, guidance_scale_lora, seed_lora, output_image_lora, lora_model_path_lora, base_model_path_lora, lora_scale_lora],
+                            outputs=[prompt_lora, height_lora, width_lora, num_inference_steps_lora, guidance_scale_lora, seed_lora, output_image_lora, lora_model_path_lora, base_model_path_lora, lora_scale_lora, output_text_lora],
+                            fn=load_example_image_lora, cache_examples=False, examples_per_page=4)
+
+        gr.HTML(outro_markdown_1)
+
+        generate_event_ft = generate_btn_ft.click(
+            fn=generate_image,
+            inputs=[prompt_ft, height_ft, width_ft,
+                    num_inference_steps_ft, guidance_scale_ft, seed_ft,
+                    random_seed_ft, gr.State(False), finetune_model_path_ft,
+                    gr.State(None), gr.State(None), gr.State(None),
+                    gr.State(config_path), gr.State(device), gr.State(dtype)],
+            outputs=[output_image_ft, output_text_ft]
+        )
+        generate_event_lora = generate_btn_lora.click(
+            fn=generate_image,
+            inputs=[prompt_lora, height_lora, width_lora,
|
| 108 |
+
num_inference_steps_lora, guidance_scale_lora, seed_lora,
|
| 109 |
+
random_seed_lora, gr.State(True), gr.State(None),
|
| 110 |
+
lora_model_path_lora, base_model_path_lora, lora_scale_lora,
|
| 111 |
+
gr.State(config_path), gr.State(device), gr.State(dtype)],
|
| 112 |
+
outputs=[output_image_lora, output_text_lora]
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
stop_btn_ft.click(fn=None, inputs=None, outputs=None, cancels=[generate_event_ft])
|
| 116 |
+
stop_btn_lora.click(fn=None, inputs=None, outputs=None, cancels=[generate_event_lora])
|
| 117 |
+
|
| 118 |
+
demo.unload(lambda: torch.cuda.empty_cache())
|
| 119 |
+
|
| 120 |
+
return demo
|
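A minimal launch sketch for context (not part of this diff): it assumes load_model_configs() from apps/gradio_app/config_loader.py returns the dict keyed by model ID that create_gui() expects; the import paths and port are illustrative and may differ from the repository's actual apps/gradio_app.py.

import torch

# Assumed import paths; adjust to how apps/gradio_app.py resolves the gradio_app package.
from gradio_app.config_loader import load_model_configs
from gradio_app.gui_components import create_gui

if __name__ == "__main__":
    model_configs = load_model_configs("configs/model_ckpts.yaml")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    demo = create_gui(model_configs, device)
    # A queue is needed so the Stop buttons can cancel in-flight generation events.
    demo.queue().launch(server_port=7860, share=False)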
apps/gradio_app/image_generator.py
ADDED
|
@@ -0,0 +1,54 @@
import os
import sys
import torch

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..',
                                              'src', 'ghibli_stable_diffusion_synthesis',
                                              'inference')))

from full_finetuning import inference_process as full_finetuning_inference
from lora import inference_process as lora_inference

def generate_image(prompt, height, width, num_inference_steps, guidance_scale, seed,
                   random_seed, use_lora, finetune_model_id, lora_model_id, base_model_id,
                   lora_scale, config_path, device, dtype):
    batch_size = 1
    if random_seed:
        seed = torch.randint(0, 4294967295, (1,)).item()
    try:
        model_id = finetune_model_id
        if not use_lora:
            pil_image = full_finetuning_inference(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                batch_size=batch_size,
                seed=seed,
                config_path=config_path,
                model_id=model_id,
                device=device,
                dtype=dtype
            )
        else:
            model_id = lora_model_id
            pil_image = lora_inference(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                batch_size=batch_size,
                seed=seed,
                lora_scale=lora_scale,
                config_path=config_path,
                model_id=model_id,
                # base_model_id=base_model_id,
                device=device,
                dtype=dtype
            )
        return pil_image, f"Generated image successfully! Seed used: {seed}"
    except Exception as e:
        return None, f"Failed to generate image: {e}"
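For reference, a hedged sketch of calling generate_image() directly, outside the Gradio UI. The model IDs mirror the example configs elsewhere in this upload; the import path is an assumption.

import torch

from gradio_app.image_generator import generate_image  # assumed import path

image, status = generate_image(
    prompt="a serene landscape in Ghibli style",
    height=512, width=512,              # dimensions should stay multiples of 8
    num_inference_steps=50, guidance_scale=3.5,
    seed=42, random_seed=False,
    use_lora=True,                      # route to the LoRA inference path
    finetune_model_id=None,
    lora_model_id="danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA",
    base_model_id="stabilityai/stable-diffusion-2-1",
    lora_scale=1.2,
    config_path="configs/model_ckpts.yaml",
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)
print(status)
if image is not None:
    image.save("ghibli_sample.png")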
apps/gradio_app/old-image_generator.py
ADDED
|
@@ -0,0 +1,77 @@
import torch
from PIL import Image
import numpy as np
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import (
    AutoencoderKL, UNet2DConditionModel,
    PNDMScheduler, StableDiffusionPipeline
)

from tqdm import tqdm
from .config_loader import load_model_configs

def generate_image(prompt, height, width, num_inference_steps, guidance_scale, seed,
                   random_seed, use_lora, finetune_model_id, lora_model_id, base_model_id,
                   lora_scale, config_path, device, dtype):
    if not prompt or height % 8 != 0 or width % 8 != 0 or num_inference_steps not in range(1, 101) or \
       guidance_scale < 1.0 or guidance_scale > 20.0 or seed < 0 or seed > 4294967295 or \
       (use_lora and (lora_scale < 0.0 or lora_scale > 2.0)):
        return None, "Invalid input parameters."

    model_configs = load_model_configs(config_path)
    finetune_model_path = model_configs.get(finetune_model_id, {}).get('local_dir', finetune_model_id)
    lora_model_path = model_configs.get(lora_model_id, {}).get('local_dir', lora_model_id)
    base_model_path = model_configs.get(base_model_id, {}).get('local_dir', base_model_id)

    generator = torch.Generator(device=device).manual_seed(torch.randint(0, 4294967295, (1,)).item() if random_seed else int(seed))

    try:
        if use_lora:
            # Load base pipeline
            pipe = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=dtype, use_safetensors=True)

            # Add LoRA weights with specified rank and scale
            pipe.load_lora_weights(lora_model_path, adapter_name="ghibli-lora",
                                   lora_scale=lora_scale)

            pipe = pipe.to(device)
            vae, tokenizer, text_encoder, unet, scheduler = pipe.vae, pipe.tokenizer, pipe.text_encoder, pipe.unet, PNDMScheduler.from_config(pipe.scheduler.config)
        else:
            vae = AutoencoderKL.from_pretrained(finetune_model_path, subfolder="vae", torch_dtype=dtype).to(device)
            tokenizer = CLIPTokenizer.from_pretrained(finetune_model_path, subfolder="tokenizer")
            text_encoder = CLIPTextModel.from_pretrained(finetune_model_path, subfolder="text_encoder", torch_dtype=dtype).to(device)
            unet = UNet2DConditionModel.from_pretrained(finetune_model_path, subfolder="unet", torch_dtype=dtype).to(device)
            scheduler = PNDMScheduler.from_pretrained(finetune_model_path, subfolder="scheduler")

        text_input = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
        text_embeddings = text_encoder(text_input.input_ids.to(device))[0].to(dtype=dtype)

        uncond_input = tokenizer([""] * 1, padding="max_length", max_length=text_input.input_ids.shape[-1], return_tensors="pt")
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0].to(dtype=dtype)
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        latents = torch.randn((1, unet.config.in_channels, height // 8, width // 8), generator=generator, dtype=dtype, device=device)
        scheduler.set_timesteps(num_inference_steps)
        latents = latents * scheduler.init_noise_sigma

        for t in tqdm(scheduler.timesteps, desc="Generating image"):
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = scheduler.scale_model_input(latent_model_input, t)
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            latents = scheduler.step(noise_pred, t, latents).prev_sample

        image = vae.decode(latents / vae.config.scaling_factor).sample
        image = (image / 2 + 0.5).clamp(0, 1).detach().cpu().permute(0, 2, 3, 1).numpy()
        pil_image = Image.fromarray((image[0] * 255).round().astype("uint8"))

        if use_lora:
            del pipe
        else:
            del vae, tokenizer, text_encoder, unet, scheduler
        torch.cuda.empty_cache()

        return pil_image, f"Generated image successfully! Seed used: {seed}"
    except Exception as e:
        return None, f"Failed to generate image: {e}"
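The denoising loop above applies classifier-free guidance: the UNet runs on a doubled batch (unconditional and text-conditioned latents) and the two noise predictions are blended before each scheduler step. The snippet below only restates that blending step in isolation; the function name is illustrative.

import torch

def cfg_combine(noise_uncond: torch.Tensor, noise_text: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # eps = eps_uncond + s * (eps_text - eps_uncond); s = 1.0 disables guidance,
    # larger s pushes the sample harder toward the prompt at the cost of diversity.
    return noise_uncond + guidance_scale * (noise_text - noise_uncond)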
apps/gradio_app/project_info.py
ADDED
|
@@ -0,0 +1,36 @@
intro_markdown_1 = """
<h3>Create Studio Ghibli-style art with Stable Diffusion AI.</h3>
""".strip()

intro_markdown_2 = """
<div style="text-align: left; font-size: 14px; display: flex; flex-direction: column; gap: 10px;">
    <div style="display: flex; align-items: center; justify-content: left; gap: 8px;">
        You can explore the GitHub source code: <a href="https://github.com/danhtran2mind/Ghibli-Stable-Diffusion-Synthesis">
        <img src="https://img.shields.io/badge/GitHub-danhtran2mind%2FGhibli--Stable--Diffusion--Synthesis-blue?style=flat&logo=github" alt="GitHub Repo">
        </a>
    </div>
    <div style="display: flex; align-items: center; justify-content: left; gap: 8px;">
        And the Hugging Face Model Hubs:
        <a href="https://huggingface.co/danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning">
        <img src="https://img.shields.io/badge/HuggingFace-danhtran2mind%2FGhibli--Stable--Diffusion--2.1--Base--finetuning-yellow?style=flat&logo=huggingface" alt="HuggingFace Model Hub">
        </a>, and
        <a href="https://huggingface.co/danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA">
        <img src="https://img.shields.io/badge/HuggingFace-danhtran2mind%2FGhibli--Stable--Diffusion--2.1--LoRA-yellow?style=flat&logo=huggingface" alt="HuggingFace Model Hub">
        </a>
    </div>
</div>
""".strip()

outro_markdown_1 = """
<div style="text-align: left; font-size: 14px; display: flex; flex-direction: column; gap: 10px;">
    <div style="display: flex; align-items: center; justify-content: left; gap: 8px;">
        These are the pre-trained base models on the Hugging Face Model Hub:
        <a href="https://huggingface.co/stabilityai/stable-diffusion-2-1">
        <img src="https://img.shields.io/badge/HuggingFace-stabilityai%2Fstable--diffusion--2--1-yellow?style=flat&logo=huggingface" alt="HuggingFace Model Hub">
        </a>, and
        <a href="https://huggingface.co/stabilityai/stable-diffusion-2-1-base">
        <img src="https://img.shields.io/badge/HuggingFace-stabilityai%2Fstable--diffusion--2--1--base-yellow?style=flat&logo=huggingface" alt="HuggingFace Model Hub">
        </a>
    </div>
</div>
""".strip()
apps/gradio_app/setup_scripts.py
ADDED
|
@@ -0,0 +1,64 @@
import subprocess
import sys
import os

def run_script(script_path, args=None):
    """
    Run a Python script using subprocess with optional arguments and handle errors.
    Returns True if successful, False otherwise.
    """
    if not os.path.isfile(script_path):
        print(f"Script not found: {script_path}")
        return False

    try:
        command = [sys.executable, script_path]
        if args:
            command.extend(args)
        result = subprocess.run(
            command,
            check=True,
            text=True,
            capture_output=True
        )
        print(f"Successfully executed {script_path}")
        print(result.stdout)
        return True
    except subprocess.CalledProcessError as e:
        print(f"Error executing {script_path}:")
        print(e.stderr)
        return False
    except Exception as e:
        print(f"Unexpected error executing {script_path}: {str(e)}")
        return False

def main():
    """
    Main function to execute download_ckpts.py with proper error handling.
    """
    scripts_dir = "scripts"
    scripts = [
        {
            "path": os.path.join(scripts_dir, "download_ckpts.py"),
            "args": []  # Empty list for args to avoid NoneType issues
        },
        # Uncomment and add arguments if needed for setup_third_party.py
        # {
        #     "path": os.path.join(scripts_dir, "setup_third_party.py"),
        #     "args": []
        # }
    ]

    for script in scripts:
        script_path = script["path"]
        args = script.get("args", [])  # Safely get args with default empty list
        print(f"Starting execution of {script_path}{' with args: ' + ' '.join(args) if args else ''}\n")

        if not run_script(script_path, args):
            print(f"Stopping execution due to error in {script_path}")
            sys.exit(1)

        print(f"Completed execution of {script_path}\n")

if __name__ == "__main__":
    main()
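run_script() can also be reused on its own. A hedged example follows; the script path matches the download_ckpts.py entry above, while the import path and the no-argument call are assumptions.

from gradio_app.setup_scripts import run_script  # assumed import path

ok = run_script("scripts/download_ckpts.py", args=[])
if not ok:
    raise SystemExit("Checkpoint download failed; see the captured stderr above.")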
apps/gradio_app/static/styles.css
ADDED
|
@@ -0,0 +1,213 @@
| 1 |
+
:root {
|
| 2 |
+
--primary-color: #10b981; /* Updated to success color */
|
| 3 |
+
--primary-hover: #0a8f66; /* Darkened shade of #10b981 for hover */
|
| 4 |
+
--accent-color: #8a1bf2; /* New variable for the second gradient color */
|
| 5 |
+
--accent-hover: #6b21a8; /* Darkened shade of #8a1bf2 for hover */
|
| 6 |
+
--secondary-color: #64748b;
|
| 7 |
+
--success-color: #10b981;
|
| 8 |
+
--warning-color: #f59e0b;
|
| 9 |
+
--danger-color: #ef4444;
|
| 10 |
+
--border-radius: 0.5rem; /* Relative unit */
|
| 11 |
+
--shadow-sm: 0 0.0625rem 0.125rem 0 rgba(0, 0, 0, 0.05);
|
| 12 |
+
--shadow-md: 0 0.25rem 0.375rem -0.0625rem rgba(0, 0, 0, 0.1);
|
| 13 |
+
--shadow-lg: 0 0.625rem 0.9375rem -0.1875rem rgba(0, 0, 0, 0.1);
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
/* Container Styles */
|
| 17 |
+
.gradio-container {
|
| 18 |
+
max-width: 75rem !important; /* Relative to viewport */
|
| 19 |
+
margin: 0 auto !important;
|
| 20 |
+
padding: 1.25rem !important; /* Relative padding */
|
| 21 |
+
font-family: 'Segoe UI', system-ui, -apple-system, sans-serif !important;
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
/* Card/Panel Styles */
|
| 25 |
+
.svelte-15lo0d9, .panel {
|
| 26 |
+
background: var(--block-background-fill) !important;
|
| 27 |
+
border-radius: var(--border-radius) !important;
|
| 28 |
+
box-shadow: var(--shadow-md) !important;
|
| 29 |
+
border: 0.0625rem solid var(--border-color-primary) !important;
|
| 30 |
+
backdrop-filter: blur(0.625rem) !important;
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
/* Button Styles */
|
| 34 |
+
button.primary {
|
| 35 |
+
background: linear-gradient(135deg, var(--primary-color), var(--accent-color)) !important;
|
| 36 |
+
border: none !important;
|
| 37 |
+
border-radius: var(--border-radius) !important;
|
| 38 |
+
padding: 0.625rem 1.25rem !important; /* Relative padding */
|
| 39 |
+
font-weight: 600 !important;
|
| 40 |
+
font-size: 1rem !important; /* Relative font size */
|
| 41 |
+
transition: all 0.3s ease !important;
|
| 42 |
+
box-shadow: var(--shadow-sm) !important;
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
button.primary:hover {
|
| 46 |
+
background: linear-gradient(135deg, var(--primary-hover), var(--accent-hover)) !important;
|
| 47 |
+
transform: translateY(-0.0625rem) !important;
|
| 48 |
+
box-shadow: var(--shadow-md) !important;
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
button.secondary {
|
| 52 |
+
background: transparent !important;
|
| 53 |
+
border: 0.0625rem solid var(--border-color-primary) !important;
|
| 54 |
+
border-radius: var(--border-radius) !important;
|
| 55 |
+
color: var(--body-text-color) !important;
|
| 56 |
+
font-weight: 500 !important;
|
| 57 |
+
font-size: 0.875rem !important; /* Relative font size */
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
/* Slider Styles */
|
| 61 |
+
.slider_input_container input[type="range"][name="cowbell"] {
|
| 62 |
+
-webkit-appearance: none !important;
|
| 63 |
+
width: 100% !important;
|
| 64 |
+
height: 0.5rem !important; /* Relative height */
|
| 65 |
+
border-radius: var(--border-radius) !important;
|
| 66 |
+
background: linear-gradient(90deg, var(--primary-color), var(--accent-color)) !important;
|
| 67 |
+
outline: none !important;
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
.slider_input_container input[type="range"][name="cowbell"]::-webkit-slider-thumb {
|
| 71 |
+
-webkit-appearance: none !important;
|
| 72 |
+
width: 1rem !important; /* Relative size */
|
| 73 |
+
height: 1rem !important;
|
| 74 |
+
border-radius: 50% !important;
|
| 75 |
+
background: var(--accent-color) !important;
|
| 76 |
+
cursor: pointer !important;
|
| 77 |
+
box-shadow: var(--shadow-sm) !important;
|
| 78 |
+
border: 0.0625rem solid var(--border-color-primary) !important;
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
.slider_input_container input[type="range"][name="cowbell"]::-webkit-slider-thumb:hover {
|
| 82 |
+
background: var(--accent-color) !important;
|
| 83 |
+
box-shadow: var(--shadow-md) !important;
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
.slider_input_container input[type="range"][name="cowbell"]::-moz-range-track {
|
| 87 |
+
height: 0.5rem !important; /* Relative height */
|
| 88 |
+
border-radius: var(--border-radius) !important;
|
| 89 |
+
background: linear-gradient(90deg, var(--primary-color), var(--accent-color)) !important;
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
.slider_input_container input[type="range"][name="cowbell"]::-moz-range-thumb {
|
| 93 |
+
width: 1rem !important; /* Relative size */
|
| 94 |
+
height: 1rem !important;
|
| 95 |
+
border-radius: 50% !important;
|
| 96 |
+
background: var(--accent-color) !important;
|
| 97 |
+
cursor: pointer !important;
|
| 98 |
+
box-shadow: var(--shadow-sm) !important;
|
| 99 |
+
border: 0.0625rem solid var(--border-color-primary) !important;
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
.slider_input_container input[type="range"][name="cowbell"]::-moz-range-thumb:hover {
|
| 103 |
+
background: var(--accent-color) !important;
|
| 104 |
+
box-shadow: var(--shadow-md) !important;
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
/* Header Styles */
|
| 108 |
+
h1, h2, h3, h4, h5, h6 {
|
| 109 |
+
font-weight: 700 !important;
|
| 110 |
+
color: var(--body-text-color) !important;
|
| 111 |
+
letter-spacing: -0.02em !important;
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
h1 {
|
| 115 |
+
font-size: 2.5rem !important; /* Kept as is, suitable for zooming */
|
| 116 |
+
margin-bottom: 1rem !important;
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
h2 {
|
| 120 |
+
font-size: 1.75rem !important; /* Kept as is, suitable for zooming */
|
| 121 |
+
margin: 1.5rem 0 1rem 0 !important;
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
/* Text Styles */
|
| 125 |
+
p, .prose {
|
| 126 |
+
color: var(--body-text-color-subdued) !important;
|
| 127 |
+
line-height: 1.6 !important;
|
| 128 |
+
font-size: 1rem !important; /* Relative font size */
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
/* Alert/Notification Styles */
|
| 132 |
+
.alert-info {
|
| 133 |
+
background: linear-gradient(135deg, #dbeafe, #bfdbfe) !important;
|
| 134 |
+
border: 0.0625rem solid #93c5fd !important;
|
| 135 |
+
border-radius: var(--border-radius) !important;
|
| 136 |
+
color: #1e40af !important;
|
| 137 |
+
font-size: 0.875rem !important; /* Relative font size */
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
.alert-warning {
|
| 141 |
+
background: linear-gradient(135deg, #fef3c7, #fde68a) !important;
|
| 142 |
+
border: 0.0625rem solid #fcd34d !important;
|
| 143 |
+
border-radius: var(--border-radius) !important;
|
| 144 |
+
color: #92400e !important;
|
| 145 |
+
font-size: 0.875rem !important; /* Relative font size */
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
.alert-error {
|
| 149 |
+
background: linear-gradient(135deg, #fecaca, #fca5a5) !important;
|
| 150 |
+
border: 0.0625rem solid #f87171 !important;
|
| 151 |
+
border-radius: var(--border-radius) !important;
|
| 152 |
+
color: #991b1b !important;
|
| 153 |
+
font-size: 0.875rem !important; /* Relative font size */
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
/* Scrollbar (Webkit browsers) */
|
| 157 |
+
::-webkit-scrollbar {
|
| 158 |
+
width: 0.5rem !important; /* Relative size */
|
| 159 |
+
height: 0.5rem !important;
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
::-webkit-scrollbar-track {
|
| 163 |
+
background: var(--background-fill-secondary) !important;
|
| 164 |
+
border-radius: 0.25rem !important;
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
::-webkit-scrollbar-thumb {
|
| 168 |
+
background: var(--secondary-color) !important;
|
| 169 |
+
border-radius: 0.25rem !important;
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
::-webkit-scrollbar-thumb:hover {
|
| 173 |
+
background: var(--primary-color) !important;
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
/* Tab Styles */
|
| 177 |
+
.gradio-container .tabs button {
|
| 178 |
+
font-size: 1.0625rem !important; /* Relative font size (17px equivalent) */
|
| 179 |
+
font-weight: bold !important;
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
/* Dark Theme Specific Overrides */
|
| 183 |
+
@media (prefers-color-scheme: dark) {
|
| 184 |
+
:root {
|
| 185 |
+
--shadow-sm: 0 0.0625rem 0.125rem 0 rgba(0, 0, 0, 0.2);
|
| 186 |
+
--shadow-md: 0 0.25rem 0.375rem -0.0625rem rgba(0, 0, 0, 0.3);
|
| 187 |
+
--shadow-lg: 0 0.625rem 0.9375rem -0.1875rem rgba(0, 0, 0, 0.3);
|
| 188 |
+
}
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
/* Light Theme Specific Overrides */
|
| 192 |
+
@media (prefers-color-scheme: light) {
|
| 193 |
+
.gradio-container {
|
| 194 |
+
background: linear-gradient(135deg, #f8fafc, #f1f5f9) !important;
|
| 195 |
+
}
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
/* Responsive adjustments for zoom */
|
| 199 |
+
@media screen and (max-width: 48rem) { /* 768px equivalent */
|
| 200 |
+
.gradio-container {
|
| 201 |
+
padding: 0.625rem !important;
|
| 202 |
+
}
|
| 203 |
+
h1 {
|
| 204 |
+
font-size: 2rem !important;
|
| 205 |
+
}
|
| 206 |
+
h2 {
|
| 207 |
+
font-size: 1.5rem !important;
|
| 208 |
+
}
|
| 209 |
+
button.primary {
|
| 210 |
+
padding: 0.5rem 1rem !important;
|
| 211 |
+
font-size: 0.875rem !important;
|
| 212 |
+
}
|
| 213 |
+
}
|
apps/old-gradio_app.py
ADDED
|
@@ -0,0 +1,261 @@
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import os
|
| 5 |
+
import gradio as gr
|
| 6 |
+
import torch
|
| 7 |
+
from PIL import Image
|
| 8 |
+
import numpy as np
|
| 9 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
| 10 |
+
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
from transformers import HfArgumentParser
|
| 13 |
+
|
| 14 |
+
def get_examples(examples_dir: str = "apps/gradip_app/assets/examples/ghibli-fine-tuned-sd-2.1") -> list:
|
| 15 |
+
"""
|
| 16 |
+
Load example data from the assets/examples directory.
|
| 17 |
+
Each example is a subdirectory containing a config.json and an image file.
|
| 18 |
+
Returns a list of [prompt, height, width, num_inference_steps, guidance_scale, seed, image_path].
|
| 19 |
+
"""
|
| 20 |
+
# Check if the directory exists
|
| 21 |
+
if not os.path.exists(examples_dir) or not os.path.isdir(examples_dir):
|
| 22 |
+
raise ValueError(f"Directory {examples_dir} does not exist or is not a directory")
|
| 23 |
+
|
| 24 |
+
# Get list of subfolder paths (e.g., 1, 2, etc.)
|
| 25 |
+
all_examples_dir = [os.path.join(examples_dir, d) for d in os.listdir(examples_dir)
|
| 26 |
+
if os.path.isdir(os.path.join(examples_dir, d))]
|
| 27 |
+
|
| 28 |
+
ans = []
|
| 29 |
+
for example_dir in all_examples_dir:
|
| 30 |
+
config_path = os.path.join(example_dir, "config.json")
|
| 31 |
+
image_path = os.path.join(example_dir, "result.png")
|
| 32 |
+
|
| 33 |
+
# Check if config.json and result.png exist
|
| 34 |
+
if not os.path.isfile(config_path):
|
| 35 |
+
print(f"Warning: config.json not found in {example_dir}")
|
| 36 |
+
continue
|
| 37 |
+
if not os.path.isfile(image_path):
|
| 38 |
+
print(f"Warning: result.png not found in {example_dir}")
|
| 39 |
+
continue
|
| 40 |
+
|
| 41 |
+
try:
|
| 42 |
+
with open(config_path, 'r') as f:
|
| 43 |
+
example_dict = json.load(f)
|
| 44 |
+
except (json.JSONDecodeError, IOError) as e:
|
| 45 |
+
print(f"Error reading or parsing {config_path}: {e}")
|
| 46 |
+
continue
|
| 47 |
+
|
| 48 |
+
# Required keys for the config
|
| 49 |
+
required_keys = ["prompt", "height", "width", "num_inference_steps", "guidance_scale", "seed", "image"]
|
| 50 |
+
if not all(key in example_dict for key in required_keys):
|
| 51 |
+
print(f"Warning: Missing required keys in {config_path}")
|
| 52 |
+
continue
|
| 53 |
+
|
| 54 |
+
# Verify that the image key in config.json matches 'result.png'
|
| 55 |
+
if example_dict["image"] != "result.png":
|
| 56 |
+
print(f"Warning: Image key in {config_path} does not match 'result.png'")
|
| 57 |
+
continue
|
| 58 |
+
|
| 59 |
+
try:
|
| 60 |
+
example_list = [
|
| 61 |
+
example_dict["prompt"],
|
| 62 |
+
example_dict["height"],
|
| 63 |
+
example_dict["width"],
|
| 64 |
+
example_dict["num_inference_steps"],
|
| 65 |
+
example_dict["guidance_scale"],
|
| 66 |
+
example_dict["seed"],
|
| 67 |
+
image_path # Use verified image path
|
| 68 |
+
]
|
| 69 |
+
ans.append(example_list)
|
| 70 |
+
except KeyError as e:
|
| 71 |
+
print(f"Error processing {config_path}: Missing key {e}")
|
| 72 |
+
continue
|
| 73 |
+
|
| 74 |
+
if not ans:
|
| 75 |
+
ans = [
|
| 76 |
+
["a serene landscape in Ghibli style", 64, 64, 50, 3.5, 42, None]
|
| 77 |
+
]
|
| 78 |
+
return ans
|
| 79 |
+
|
| 80 |
+
def create_demo(
|
| 81 |
+
model_name: str = "danhtran2mind/ghibli-fine-tuned-sd-2.1",
|
| 82 |
+
device: str = "cuda" if torch.cuda.is_available() else "cpu",
|
| 83 |
+
):
|
| 84 |
+
# Convert device string to torch.device
|
| 85 |
+
device = torch.device(device)
|
| 86 |
+
dtype = torch.float16 if device.type == "cuda" else torch.float32
|
| 87 |
+
|
| 88 |
+
# Load models with consistent dtype
|
| 89 |
+
vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae", torch_dtype=dtype).to(device)
|
| 90 |
+
tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
|
| 91 |
+
text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder", torch_dtype=dtype).to(device)
|
| 92 |
+
unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet", torch_dtype=dtype).to(device)
|
| 93 |
+
scheduler = PNDMScheduler.from_pretrained(model_name, subfolder="scheduler")
|
| 94 |
+
|
| 95 |
+
def generate_image(prompt, height, width, num_inference_steps, guidance_scale, seed, random_seed):
|
| 96 |
+
if not prompt:
|
| 97 |
+
return None, "Prompt cannot be empty."
|
| 98 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 99 |
+
return None, "Height and width must be divisible by 8 (e.g., 256, 512, 1024)."
|
| 100 |
+
if num_inference_steps < 1 or num_inference_steps > 100:
|
| 101 |
+
return None, "Number of inference steps must be between 1 and 100."
|
| 102 |
+
if guidance_scale < 1.0 or guidance_scale > 20.0:
|
| 103 |
+
return None, "Guidance scale must be between 1.0 and 20.0."
|
| 104 |
+
if seed < 0 or seed > 4294967295:
|
| 105 |
+
return None, "Seed must be between 0 and 4294967295."
|
| 106 |
+
|
| 107 |
+
batch_size = 1
|
| 108 |
+
if random_seed:
|
| 109 |
+
seed = torch.randint(0, 4294967295, (1,)).item()
|
| 110 |
+
generator = torch.Generator(device=device).manual_seed(int(seed))
|
| 111 |
+
|
| 112 |
+
text_input = tokenizer(
|
| 113 |
+
[prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
|
| 114 |
+
)
|
| 115 |
+
with torch.no_grad():
|
| 116 |
+
text_embeddings = text_encoder(text_input.input_ids.to(device))[0].to(dtype=dtype)
|
| 117 |
+
|
| 118 |
+
max_length = text_input.input_ids.shape[-1]
|
| 119 |
+
uncond_input = tokenizer(
|
| 120 |
+
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
|
| 121 |
+
)
|
| 122 |
+
with torch.no_grad():
|
| 123 |
+
uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0].to(dtype=dtype)
|
| 124 |
+
|
| 125 |
+
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
| 126 |
+
|
| 127 |
+
latents = torch.randn(
|
| 128 |
+
(batch_size, unet.config.in_channels, height // 8, width // 8),
|
| 129 |
+
generator=generator,
|
| 130 |
+
dtype=dtype,
|
| 131 |
+
device=device
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
scheduler.set_timesteps(num_inference_steps)
|
| 135 |
+
latents = latents * scheduler.init_noise_sigma
|
| 136 |
+
|
| 137 |
+
for t in tqdm(scheduler.timesteps, desc="Generating image"):
|
| 138 |
+
latent_model_input = torch.cat([latents] * 2)
|
| 139 |
+
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
|
| 140 |
+
|
| 141 |
+
with torch.no_grad():
|
| 142 |
+
if device.type == "cuda":
|
| 143 |
+
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
| 144 |
+
noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
| 145 |
+
else:
|
| 146 |
+
noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
| 147 |
+
|
| 148 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 149 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 150 |
+
latents = scheduler.step(noise_pred, t, latents).prev_sample
|
| 151 |
+
|
| 152 |
+
with torch.no_grad():
|
| 153 |
+
latents = latents / vae.config.scaling_factor
|
| 154 |
+
image = vae.decode(latents).sample
|
| 155 |
+
|
| 156 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
| 157 |
+
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
|
| 158 |
+
image = (image * 255).round().astype("uint8")
|
| 159 |
+
pil_image = Image.fromarray(image[0])
|
| 160 |
+
|
| 161 |
+
return pil_image, f"Image generated successfully! Seed used: {seed}"
|
| 162 |
+
|
| 163 |
+
def load_example_image(prompt, height, width, num_inference_steps, guidance_scale, seed, image_path):
|
| 164 |
+
"""
|
| 165 |
+
Load the image for the selected example and update input fields.
|
| 166 |
+
"""
|
| 167 |
+
if image_path and Path(image_path).exists():
|
| 168 |
+
try:
|
| 169 |
+
image = Image.open(image_path)
|
| 170 |
+
return prompt, height, width, num_inference_steps, guidance_scale, seed, image, f"Loaded image: {image_path}"
|
| 171 |
+
except Exception as e:
|
| 172 |
+
return prompt, height, width, num_inference_steps, guidance_scale, seed, None, f"Error loading image: {e}"
|
| 173 |
+
return prompt, height, width, num_inference_steps, guidance_scale, seed, None, "No image available"
|
| 174 |
+
|
| 175 |
+
badges_text = r"""
|
| 176 |
+
<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
|
| 177 |
+
<a href="https://huggingface.co/spaces/danhtran2mind/ghibli-fine-tuned-sd-2.1"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Hugging%20Face&message=Space&color=orange"></a>
|
| 178 |
+
</div>
|
| 179 |
+
""".strip()
|
| 180 |
+
|
| 181 |
+
with gr.Blocks() as demo:
|
| 182 |
+
gr.Markdown("# Ghibli-Style Image Generator")
|
| 183 |
+
gr.Markdown(badges_text)
|
| 184 |
+
gr.Markdown("Generate images in Ghibli style using a fine-tuned Stable Diffusion model. Select an example below to load a pre-generated image or enter a prompt to generate a new one.")
|
| 185 |
+
gr.Markdown("""**Note:** For CPU inference, execution time is long (e.g., for resolution 512 × 512) with 50 inference steps, time is approximately 1700 seconds).""")
|
| 186 |
+
|
| 187 |
+
with gr.Row():
|
| 188 |
+
with gr.Column():
|
| 189 |
+
prompt = gr.Textbox(label="Prompt", placeholder="e.g., 'a serene landscape in Ghibli style'")
|
| 190 |
+
with gr.Row():
|
| 191 |
+
width = gr.Slider(32, 4096, 512, step=8, label="Generation Width")
|
| 192 |
+
height = gr.Slider(32, 4096, 512, step=8, label="Generation Height")
|
| 193 |
+
with gr.Accordion("Advanced Options", open=False):
|
| 194 |
+
num_inference_steps = gr.Slider(1, 100, 50, step=1, label="Number of Inference Steps")
|
| 195 |
+
guidance_scale = gr.Slider(1.0, 20.0, 3.5, step=0.5, label="Guidance Scale")
|
| 196 |
+
seed = gr.Number(42, label="Seed (0 to 4294967295)")
|
| 197 |
+
random_seed = gr.Checkbox(label="Use Random Seed", value=False)
|
| 198 |
+
generate_btn = gr.Button("Generate Image")
|
| 199 |
+
|
| 200 |
+
with gr.Column():
|
| 201 |
+
output_image = gr.Image(label="Generated Image")
|
| 202 |
+
output_text = gr.Textbox(label="Status")
|
| 203 |
+
|
| 204 |
+
examples = get_examples("assets/examples/ghibli-fine-tuned-sd-2.1")
|
| 205 |
+
gr.Examples(
|
| 206 |
+
examples=examples,
|
| 207 |
+
inputs=[prompt, height, width, num_inference_steps, guidance_scale, seed, output_image],
|
| 208 |
+
outputs=[prompt, height, width, num_inference_steps, guidance_scale, seed, output_image, output_text],
|
| 209 |
+
fn=load_example_image,
|
| 210 |
+
cache_examples=False
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
generate_btn.click(
|
| 214 |
+
fn=generate_image,
|
| 215 |
+
inputs=[prompt, height, width, num_inference_steps, guidance_scale, seed, random_seed],
|
| 216 |
+
outputs=[output_image, output_text]
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
return demo
|
| 220 |
+
|
| 221 |
+
if __name__ == "__main__":
|
| 222 |
+
parser = argparse.ArgumentParser(description="Ghibli-Style Image Generator using a fine-tuned Stable Diffusion model.")
|
| 223 |
+
parser.add_argument(
|
| 224 |
+
"--local_model",
|
| 225 |
+
action="store_true",
|
| 226 |
+
default=True,
|
| 227 |
+
help="Use local model path instead of Hugging Face model."
|
| 228 |
+
)
|
| 229 |
+
parser.add_argument(
|
| 230 |
+
"--model_name",
|
| 231 |
+
type=str,
|
| 232 |
+
default="danhtran2mind/ghibli-fine-tuned-sd-2.1",
|
| 233 |
+
help="Model name or path for the fine-tuned Stable Diffusion model."
|
| 234 |
+
)
|
| 235 |
+
parser.add_argument(
|
| 236 |
+
"--device",
|
| 237 |
+
type=str,
|
| 238 |
+
default="cuda" if torch.cuda.is_available() else "cpu",
|
| 239 |
+
help="Device to run the model on (e.g., 'cuda', 'cpu')."
|
| 240 |
+
)
|
| 241 |
+
parser.add_argument(
|
| 242 |
+
"--port",
|
| 243 |
+
type=int,
|
| 244 |
+
default=7860,
|
| 245 |
+
help="Port to run the Gradio app on."
|
| 246 |
+
)
|
| 247 |
+
parser.add_argument(
|
| 248 |
+
"--share",
|
| 249 |
+
action="store_true",
|
| 250 |
+
default=False,
|
| 251 |
+
help="Set to True for public sharing (Hugging Face Spaces)."
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
args = parser.parse_args()
|
| 255 |
+
|
| 256 |
+
# Set model_name based on local_model flag
|
| 257 |
+
if args.local_model:
|
| 258 |
+
args.model_name = "./checkpoints/ghibli-fine-tuned-sd-2.1"
|
| 259 |
+
|
| 260 |
+
demo = create_demo(args.model_name, args.device)
|
| 261 |
+
demo.launch(server_port=args.port, share=args.share)
|
apps/old2-gradio_app.py
ADDED
|
@@ -0,0 +1,376 @@
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import os
|
| 5 |
+
import gradio as gr
|
| 6 |
+
import torch
|
| 7 |
+
from PIL import Image
|
| 8 |
+
import numpy as np
|
| 9 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
| 10 |
+
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, StableDiffusionPipeline
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
import yaml
|
| 13 |
+
|
| 14 |
+
def load_model_configs(config_path: str = "configs/model_ckpts.yaml") -> dict:
|
| 15 |
+
"""
|
| 16 |
+
Load model configurations from a YAML file.
|
| 17 |
+
Returns a dictionary with model IDs and their details.
|
| 18 |
+
"""
|
| 19 |
+
try:
|
| 20 |
+
with open(config_path, 'r') as f:
|
| 21 |
+
configs = yaml.safe_load(f)
|
| 22 |
+
return {cfg['model_id']: cfg for cfg in configs}
|
| 23 |
+
except (IOError, yaml.YAMLError) as e:
|
| 24 |
+
raise ValueError(f"Error loading {config_path}: {e}")
|
| 25 |
+
|
| 26 |
+
def get_examples(examples_dir: str = "apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning") -> list:
|
| 27 |
+
"""
|
| 28 |
+
Load example data from the assets/examples directory.
|
| 29 |
+
Each example is a subdirectory containing a config.json and an image file.
|
| 30 |
+
Returns a list of [prompt, height, width, num_inference_steps, guidance_scale, seed, image_path, use_lora, finetune_model_path, lora_model_path, base_model_path, lora_rank, lora_scale].
|
| 31 |
+
"""
|
| 32 |
+
if not os.path.exists(examples_dir) or not os.path.isdir(examples_dir):
|
| 33 |
+
raise ValueError(f"Directory {examples_dir} does not exist or is not a directory")
|
| 34 |
+
|
| 35 |
+
all_examples_dir = [os.path.join(examples_dir, d) for d in os.listdir(examples_dir)
|
| 36 |
+
if os.path.isdir(os.path.join(examples_dir, d))]
|
| 37 |
+
|
| 38 |
+
ans = []
|
| 39 |
+
for example_dir in all_examples_dir:
|
| 40 |
+
config_path = os.path.join(example_dir, "config.json")
|
| 41 |
+
image_path = os.path.join(example_dir, "result.png")
|
| 42 |
+
|
| 43 |
+
if not os.path.isfile(config_path):
|
| 44 |
+
print(f"Warning: config.json not found in {example_dir}")
|
| 45 |
+
continue
|
| 46 |
+
if not os.path.isfile(image_path):
|
| 47 |
+
print(f"Warning: result.png not found in {example_dir}")
|
| 48 |
+
continue
|
| 49 |
+
|
| 50 |
+
try:
|
| 51 |
+
with open(config_path, 'r') as f:
|
| 52 |
+
example_dict = json.load(f)
|
| 53 |
+
except (json.JSONDecodeError, IOError) as e:
|
| 54 |
+
print(f"Error reading or parsing {config_path}: {e}")
|
| 55 |
+
continue
|
| 56 |
+
|
| 57 |
+
required_keys = ["prompt", "height", "width", "num_inference_steps", "guidance_scale", "seed", "image"]
|
| 58 |
+
if not all(key in example_dict for key in required_keys):
|
| 59 |
+
print(f"Warning: Missing required keys in {config_path}")
|
| 60 |
+
continue
|
| 61 |
+
|
| 62 |
+
if example_dict["image"] != "result.png":
|
| 63 |
+
print(f"Warning: Image key in {config_path} does not match 'result.png'")
|
| 64 |
+
continue
|
| 65 |
+
|
| 66 |
+
try:
|
| 67 |
+
example_list = [
|
| 68 |
+
example_dict["prompt"],
|
| 69 |
+
example_dict["height"],
|
| 70 |
+
example_dict["width"],
|
| 71 |
+
example_dict["num_inference_steps"],
|
| 72 |
+
example_dict["guidance_scale"],
|
| 73 |
+
example_dict["seed"],
|
| 74 |
+
image_path,
|
| 75 |
+
example_dict.get("use_lora", False),
|
| 76 |
+
example_dict.get("finetune_model_path", "danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning"),
|
| 77 |
+
example_dict.get("lora_model_path", "danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA"),
|
| 78 |
+
example_dict.get("base_model_path", "stabilityai/stable-diffusion-2-1"),
|
| 79 |
+
example_dict.get("lora_rank", 64),
|
| 80 |
+
example_dict.get("lora_scale", 1.2)
|
| 81 |
+
]
|
| 82 |
+
ans.append(example_list)
|
| 83 |
+
except KeyError as e:
|
| 84 |
+
print(f"Error processing {config_path}: Missing key {e}")
|
| 85 |
+
continue
|
| 86 |
+
|
| 87 |
+
if not ans:
|
| 88 |
+
model_configs = load_model_configs("configs/model_ckpts.yaml")
|
| 89 |
+
finetune_model_id = "danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning"
|
| 90 |
+
lora_model_id = "danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA"
|
| 91 |
+
base_model_id = model_configs[lora_model_id]['base_model_id'] if lora_model_id in model_configs else "stabilityai/stable-diffusion-2-1"
|
| 92 |
+
ans = [
|
| 93 |
+
["a serene landscape in Ghibli style", 512, 512, 50, 3.5, 42, None, False,
|
| 94 |
+
model_configs.get(finetune_model_id, {}).get('local_dir', finetune_model_id),
|
| 95 |
+
model_configs.get(lora_model_id, {}).get('local_dir', lora_model_id),
|
| 96 |
+
base_model_id, 64, 1.2]
|
| 97 |
+
]
|
| 98 |
+
return ans
|
| 99 |
+
|
| 100 |
+
def create_demo(
|
| 101 |
+
config_path: str = "configs/model_ckpts.yaml",
|
| 102 |
+
device: str = "cuda" if torch.cuda.is_available() else "cpu",
|
| 103 |
+
):
|
| 104 |
+
# Load model configurations
|
| 105 |
+
model_configs = load_model_configs(config_path)
|
| 106 |
+
finetune_model_id = "danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning"
|
| 107 |
+
lora_model_id = "danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA"
|
| 108 |
+
finetune_model_path = model_configs[finetune_model_id]['local_dir'] if model_configs[finetune_model_id]['platform'] == "Local" else finetune_model_id
|
| 109 |
+
lora_model_path = model_configs[lora_model_id]['local_dir'] if model_configs[lora_model_id]['platform'] == "Local" else lora_model_id
|
| 110 |
+
base_model_path = model_configs[lora_model_id]['base_model_id']
|
| 111 |
+
|
| 112 |
+
# Convert device string to torch.device
|
| 113 |
+
device = torch.device(device)
|
| 114 |
+
dtype = torch.float16 if device.type == "cuda" else torch.float32
|
| 115 |
+
|
| 116 |
+
# Extract model IDs for dropdown choices based on type
|
| 117 |
+
finetune_model_ids = [mid for mid, cfg in model_configs.items() if cfg.get('type') == 'full-finetuning']
|
| 118 |
+
lora_model_ids = [mid for mid, cfg in model_configs.items() if cfg.get('type') == 'lora']
|
| 119 |
+
base_model_ids = [model_configs[mid]['base_model_id'] for mid in model_configs if 'base_model_id' in model_configs[mid]]
|
| 120 |
+
|
| 121 |
+
def update_model_path_visibility(use_lora):
|
| 122 |
+
"""
|
| 123 |
+
Update visibility of model path dropdowns based on use_lora checkbox.
|
| 124 |
+
"""
|
| 125 |
+
if use_lora:
|
| 126 |
+
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)
|
| 127 |
+
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
|
| 128 |
+
|
| 129 |
+
def generate_image(prompt, height, width, num_inference_steps, guidance_scale, seed, random_seed, use_lora, finetune_model_path, lora_model_path, base_model_path, lora_rank, lora_scale):
|
| 130 |
+
if not prompt:
|
| 131 |
+
return None, "Prompt cannot be empty."
|
| 132 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 133 |
+
return None, "Height and width must be divisible by 8 (e.g., 256, 512, 1024)."
|
| 134 |
+
if num_inference_steps < 1 or num_inference_steps > 100:
|
| 135 |
+
return None, "Number of inference steps must be between 1 and 100."
|
| 136 |
+
if guidance_scale < 1.0 or guidance_scale > 20.0:
|
| 137 |
+
return None, "Guidance scale must be between 1.0 and 20.0."
|
| 138 |
+
if seed < 0 or seed > 4294967295:
|
| 139 |
+
return None, "Seed must be between 0 and 4294967295."
|
| 140 |
+
if use_lora and (not lora_model_path or not os.path.exists(lora_model_path) and not lora_model_path.startswith("danhtran2mind/")):
|
| 141 |
+
return None, f"LoRA model path {lora_model_path} does not exist or is invalid."
|
| 142 |
+
if use_lora and (not base_model_path or not os.path.exists(base_model_path) and not base_model_path.startswith("stabilityai/")):
|
| 143 |
+
return None, f"Base model path {base_model_path} does not exist or is invalid."
|
| 144 |
+
if not use_lora and (not finetune_model_path or not os.path.exists(finetune_model_path) and not finetune_model_path.startswith("danhtran2mind/")):
|
| 145 |
+
return None, f"Fine-tuned model path {finetune_model_path} does not exist or is invalid."
|
| 146 |
+
if use_lora and (lora_rank < 1 or lora_rank > 128):
|
| 147 |
+
return None, "LoRA rank must be between 1 and 128."
|
| 148 |
+
if use_lora and (lora_scale < 0.0 or lora_scale > 2.0):
|
| 149 |
+
return None, "LoRA scale must be between 0.0 and 2.0."
|
| 150 |
+
|
| 151 |
+
batch_size = 1
|
| 152 |
+
if random_seed:
|
| 153 |
+
seed = torch.randint(0, 4294967295, (1,)).item()
|
| 154 |
+
generator = torch.Generator(device=device).manual_seed(int(seed))
|
| 155 |
+
|
| 156 |
+
# Load models based on use_lora
|
| 157 |
+
if use_lora:
|
| 158 |
+
try:
|
| 159 |
+
pipe = StableDiffusionPipeline.from_pretrained(
|
| 160 |
+
base_model_path,
|
| 161 |
+
torch_dtype=dtype,
|
| 162 |
+
use_safetensors=True
|
| 163 |
+
)
|
| 164 |
+
pipe.load_lora_weights(lora_model_path, adapter_name="ghibli-lora", lora_scale=lora_scale)
|
| 165 |
+
pipe = pipe.to(device)
|
| 166 |
+
vae = pipe.vae
|
| 167 |
+
tokenizer = pipe.tokenizer
|
| 168 |
+
text_encoder = pipe.text_encoder
|
| 169 |
+
unet = pipe.unet
|
| 170 |
+
scheduler = PNDMScheduler.from_config(pipe.scheduler.config)
|
| 171 |
+
except Exception as e:
|
| 172 |
+
return None, f"Error loading LoRA model from {lora_model_path} or base model from {base_model_path}: {e}"
|
| 173 |
+
else:
|
| 174 |
+
try:
|
| 175 |
+
vae = AutoencoderKL.from_pretrained(finetune_model_path, subfolder="vae", torch_dtype=dtype).to(device)
|
| 176 |
+
tokenizer = CLIPTokenizer.from_pretrained(finetune_model_path, subfolder="tokenizer")
|
| 177 |
+
text_encoder = CLIPTextModel.from_pretrained(finetune_model_path, subfolder="text_encoder", torch_dtype=dtype).to(device)
|
| 178 |
+
unet = UNet2DConditionModel.from_pretrained(finetune_model_path, subfolder="unet", torch_dtype=dtype).to(device)
|
| 179 |
+
scheduler = PNDMScheduler.from_pretrained(finetune_model_path, subfolder="scheduler")
|
| 180 |
+
except Exception as e:
|
| 181 |
+
return None, f"Error loading fine-tuned model from {finetune_model_path}: {e}"
|
| 182 |
+
|
| 183 |
+
text_input = tokenizer(
|
| 184 |
+
[prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
|
| 185 |
+
)
|
| 186 |
+
with torch.no_grad():
|
| 187 |
+
text_embeddings = text_encoder(text_input.input_ids.to(device))[0].to(dtype=dtype)
|
| 188 |
+
|
| 189 |
+
max_length = text_input.input_ids.shape[-1]
|
| 190 |
+
uncond_input = tokenizer(
|
| 191 |
+
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
|
| 192 |
+
)
|
| 193 |
+
with torch.no_grad():
|
| 194 |
+
uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0].to(dtype=dtype)
|
| 195 |
+
|
| 196 |
+
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
| 197 |
+
|
| 198 |
+
latents = torch.randn(
|
| 199 |
+
(batch_size, unet.config.in_channels, height // 8, width // 8),
|
| 200 |
+
generator=generator,
|
| 201 |
+
dtype=dtype,
|
| 202 |
+
device=device
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
scheduler.set_timesteps(num_inference_steps)
|
| 206 |
+
latents = latents * scheduler.init_noise_sigma
|
| 207 |
+
|
| 208 |
+
for t in tqdm(scheduler.timesteps, desc="Generating image"):
|
| 209 |
+
latent_model_input = torch.cat([latents] * 2)
|
| 210 |
+
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
|
| 211 |
+
|
| 212 |
+
with torch.no_grad():
|
| 213 |
+
if device.type == "cuda":
|
| 214 |
+
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
| 215 |
+
noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
| 216 |
+
else:
|
| 217 |
+
noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
| 218 |
+
|
| 219 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 220 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 221 |
+
latents = scheduler.step(noise_pred, t, latents).prev_sample
|
| 222 |
+
|
| 223 |
+
with torch.no_grad():
|
| 224 |
+
latents = latents / vae.config.scaling_factor
|
| 225 |
+
image = vae.decode(latents).sample
|
| 226 |
+
|
| 227 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
| 228 |
+
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
|
| 229 |
+
image = (image * 255).round().astype("uint8")
|
| 230 |
+
pil_image = Image.fromarray(image[0])
|
| 231 |
+
|
| 232 |
+
return pil_image, f"Image generated successfully! Seed used: {seed}"
|

    def load_example_image(prompt, height, width, num_inference_steps, guidance_scale, seed, image_path, use_lora, finetune_model_path, lora_model_path, base_model_path, lora_rank, lora_scale):
        """
        Load the image for the selected example and update input fields.
        """
        if image_path and Path(image_path).exists():
            try:
                image = Image.open(image_path)
                return (
                    prompt, height, width, num_inference_steps, guidance_scale, seed, image,
                    use_lora, finetune_model_path, lora_model_path, base_model_path, lora_rank, lora_scale,
                    f"Loaded image: {image_path}"
                )
            except Exception as e:
                return (
                    prompt, height, width, num_inference_steps, guidance_scale, seed, None,
                    use_lora, finetune_model_path, lora_model_path, base_model_path, lora_rank, lora_scale,
                    f"Error loading image: {e}"
                )
        return (
            prompt, height, width, num_inference_steps, guidance_scale, seed, None,
            use_lora, finetune_model_path, lora_model_path, base_model_path, lora_rank, lora_scale,
            "No image available"
        )

    badges_text = r"""
    <div style="text-align: left; font-size: 14px; display: flex; flex-direction: column; gap: 10px;">
        <div style="display: flex; align-items: center; justify-content: left; gap: 8px;">
            You can explore GitHub repository:
            <a href="https://github.com/danhtran2mind/Ghibli-Stable-Diffusion-Synthesis">
                <img src="https://img.shields.io/badge/GitHub-danhtran2mind%2FGhibli--Stable--Diffusion--Synthesis-blue?style=flat&logo=github" alt="GitHub Repo">
            </a>.
        </div>
        <div style="display: flex; align-items: center; justify-content: left; gap: 8px;">
            And you can explore HuggingFace Model Hub:
            <a href="https://huggingface.co/spaces/danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning">
                <img src="https://img.shields.io/badge/HuggingFace-danhtran2mind%2FGhibli--Stable--Diffusion--2.1--Base--finetuning-yellow?style=flat&logo=huggingface" alt="HuggingFace Space Demo">
            </a>
            and
            <a href="https://huggingface.co/spaces/danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA">
                <img src="https://img.shields.io/badge/HuggingFace-danhtran2mind%2FGhibli--Stable--Diffusion--2.1--LoRA-yellow?style=flat&logo=huggingface" alt="HuggingFace Space Demo">
            </a>
        </div>
    </div>
    """.strip()

    with gr.Blocks() as demo:
        gr.Markdown("# Ghibli-Style Image Generator")
        gr.Markdown(badges_text)
        gr.Markdown("Generate images in Ghibli style using a fine-tuned Stable Diffusion model or Stable Diffusion 2.1 with LoRA weights. Select an example below to load a pre-generated image or enter a prompt to generate a new one.")
        gr.Markdown("""**Note:** For CPU inference, execution time is long (e.g., for resolution 512 × 512 with 50 inference steps, time is approximately 1700 seconds).""")

        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", placeholder="e.g., 'a serene landscape in Ghibli style'")
                with gr.Row():
                    width = gr.Slider(32, 4096, 512, step=8, label="Generation Width")
                    height = gr.Slider(32, 4096, 512, step=8, label="Generation Height")
                with gr.Accordion("Advanced Options", open=False):
                    num_inference_steps = gr.Slider(1, 100, 50, step=1, label="Number of Inference Steps")
                    guidance_scale = gr.Slider(1.0, 20.0, 3.5, step=0.5, label="Guidance Scale")
                    seed = gr.Number(42, label="Seed (0 to 4294967295)")
                    random_seed = gr.Checkbox(label="Use Random Seed", value=False)
                    use_lora = gr.Checkbox(label="Use LoRA Weights", value=False)
                    finetune_model_path = gr.Dropdown(
                        label="Fine-tuned Model Path",
                        choices=finetune_model_ids,
                        value=finetune_model_id,
                        visible=not use_lora.value
                    )
                    lora_model_path = gr.Dropdown(
                        label="LoRA Model Path",
                        choices=lora_model_ids,
                        value=lora_model_id,
                        visible=use_lora.value
                    )
                    base_model_path = gr.Dropdown(
                        label="Base Model Path",
                        choices=base_model_ids,
                        value=base_model_path,
                        visible=use_lora.value
                    )
                    lora_rank = gr.Slider(1, 128, 64, step=1, label="LoRA Rank", visible=use_lora.value)
                    lora_scale = gr.Slider(0.0, 2.0, 1.2, step=0.1, label="LoRA Scale", visible=use_lora.value)
                generate_btn = gr.Button("Generate Image")

            with gr.Column():
                output_image = gr.Image(label="Generated Image")
                output_text = gr.Textbox(label="Status")

        examples = get_examples("assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning")
        gr.Examples(
            examples=examples,
            inputs=[prompt, height, width, num_inference_steps, guidance_scale, seed, output_image, use_lora, finetune_model_path, lora_model_path, base_model_path, lora_rank, lora_scale],
            outputs=[prompt, height, width, num_inference_steps, guidance_scale, seed, output_image, use_lora, finetune_model_path, lora_model_path, base_model_path, lora_rank, lora_scale, output_text],
            fn=load_example_image,
            cache_examples=False
        )

        use_lora.change(
            fn=update_model_path_visibility,
            inputs=use_lora,
            outputs=[lora_model_path, base_model_path, finetune_model_path]
        )

        generate_btn.click(
            fn=generate_image,
            inputs=[prompt, height, width, num_inference_steps, guidance_scale, seed, random_seed, use_lora, finetune_model_path, lora_model_path, base_model_path, lora_rank, lora_scale],
            outputs=[output_image, output_text]
        )

    return demo

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Ghibli-Style Image Generator using a fine-tuned Stable Diffusion model or Stable Diffusion 2.1 with LoRA weights.")
    parser.add_argument(
        "--config_path",
        type=str,
        default="configs/model_ckpts.yaml",
        help="Path to the model configuration YAML file."
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to run the model on (e.g., 'cuda', 'cpu')."
    )
    parser.add_argument(
        "--port",
        type=int,
        default=7860,
        help="Port to run the Gradio app on."
    )
    parser.add_argument(
        "--share",
        action="store_true",
        default=False,
        help="Set to True for public sharing (Hugging Face Spaces)."
    )

    args = parser.parse_args()

    demo = create_demo(args.config_path, args.device)
    demo.launch(server_port=args.port, share=args.share)
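The app above resolves checkpoints through configs/model_ckpts.yaml via load_model_configs(), keying each entry by model_id and reading type, local_dir, and base_model_id. A hedged sketch of a config with that shape follows; the concrete ids and directories are illustrative assumptions, not the repository's actual values.

# Sketch of the expected configs/model_ckpts.yaml structure (all values are assumptions).
import yaml

example_configs = [
    {
        "model_id": "danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning",  # assumed id
        "type": "full_finetuning",
        "local_dir": "ckpts/Ghibli-Stable-Diffusion-2.1-Base-finetuning",  # assumed path
    },
    {
        "model_id": "danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA",  # assumed id
        "type": "lora",
        "base_model_id": "stabilityai/stable-diffusion-2-1",
        "local_dir": "ckpts/Ghibli-Stable-Diffusion-2.1-LoRA",  # assumed path
    },
]
print(yaml.safe_dump(example_configs, sort_keys=False))
# load_model_configs() indexes this YAML list as {model_id: entry, ...} for the dropdowns above.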
apps/old3-gradio_app.py
ADDED
@@ -0,0 +1,438 @@
import argparse
import json
from pathlib import Path
import os
import gradio as gr
import torch
from PIL import Image
import numpy as np
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, StableDiffusionPipeline
from tqdm import tqdm
import yaml

def load_model_configs(config_path: str = "configs/model_ckpts.yaml") -> dict:
    """
    Load model configurations from a YAML file.
    Returns a dictionary with model IDs and their details.
    """
    try:
        with open(config_path, 'r') as f:
            configs = yaml.safe_load(f)
        return {cfg['model_id']: cfg for cfg in configs}
    except (IOError, yaml.YAMLError) as e:
        raise ValueError(f"Error loading {config_path}: {e}")

def get_examples(examples_dir: str = "apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning") -> list:
    """
    Load example data from the assets/examples directory.
    Each example is a subdirectory containing a config.json and an image file.
    Returns a list of [prompt, height, width, num_inference_steps, guidance_scale, seed, image_path, use_lora, finetune_model_id, lora_model_id, base_model_id, lora_rank, lora_scale].
    """
    if not os.path.exists(examples_dir) or not os.path.isdir(examples_dir):
        raise ValueError(f"Directory {examples_dir} does not exist or is not a directory")

    all_examples_dir = [os.path.join(examples_dir, d) for d in os.listdir(examples_dir)
                        if os.path.isdir(os.path.join(examples_dir, d))]

    ans = []
    for example_dir in all_examples_dir:
        config_path = os.path.join(example_dir, "config.json")
        image_path = os.path.join(example_dir, "result.png")

        if not os.path.isfile(config_path):
            print(f"Warning: config.json not found in {example_dir}")
            continue
        if not os.path.isfile(image_path):
            print(f"Warning: result.png not found in {example_dir}")
            continue

        try:
            with open(config_path, 'r') as f:
                example_dict = json.load(f)
        except (json.JSONDecodeError, IOError) as e:
            print(f"Error reading or parsing {config_path}: {e}")
            continue

        required_keys = ["prompt", "height", "width", "num_inference_steps", "guidance_scale", "seed", "image"]
        if not all(key in example_dict for key in required_keys):
            print(f"Warning: Missing required keys in {config_path}")
            continue

        if example_dict["image"] != "result.png":
            print(f"Warning: Image key in {config_path} does not match 'result.png'")
            continue

        try:
            model_configs = load_model_configs("configs/model_ckpts.yaml")
            finetune_model_id = next((mid for mid, cfg in model_configs.items() if cfg.get('type') == 'full_finetuning'), None)
            lora_model_id = next((mid for mid, cfg in model_configs.items() if cfg.get('type') == 'lora'), None)

            example_list = [
                example_dict["prompt"],
                example_dict["height"],
                example_dict["width"],
                example_dict["num_inference_steps"],
                example_dict["guidance_scale"],
                example_dict["seed"],
                image_path,
                example_dict.get("use_lora", False),
                finetune_model_id if finetune_model_id else "stabilityai/stable-diffusion-2-1-base",
                lora_model_id if lora_model_id else "stabilityai/stable-diffusion-2-1",
                model_configs.get(lora_model_id, {}).get('base_model_id', "stabilityai/stable-diffusion-2-1") if lora_model_id else "stabilityai/stable-diffusion-2-1",
                example_dict.get("lora_rank", 64),
                example_dict.get("lora_scale", 1.2)
            ]
            ans.append(example_list)
        except KeyError as e:
            print(f"Error processing {config_path}: Missing key {e}")
            continue

    if not ans:
        model_configs = load_model_configs("configs/model_ckpts.yaml")
        finetune_model_id = next((mid for mid, cfg in model_configs.items() if cfg.get('type') == 'full_finetuning'), "stabilityai/stable-diffusion-2-1-base")
        lora_model_id = next((mid for mid, cfg in model_configs.items() if cfg.get('type') == 'lora'), "stabilityai/stable-diffusion-2-1")
        base_model_id = model_configs.get(lora_model_id, {}).get('base_model_id', "stabilityai/stable-diffusion-2-1")

        ans = [
            ["a serene landscape in Ghibli style", 512, 512, 50, 3.5, 42, None, False,
             finetune_model_id,
             lora_model_id,
             base_model_id, 64, 1.2]
        ]
    return ans

def create_demo(
    config_path: str = "configs/model_ckpts.yaml",
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    # Load model configurations
    model_configs = load_model_configs(config_path)

    # Load model IDs from YAML
    finetune_model_id = next((mid for mid, cfg in model_configs.items() if cfg.get('type') == 'full_finetuning'), None)
    lora_model_id = next((mid for mid, cfg in model_configs.items() if cfg.get('type') == 'lora'), None)

    if not finetune_model_id or not lora_model_id:
        raise ValueError("Could not find full_finetuning or lora model IDs in the configuration file.")

    # Determine finetune model path
    finetune_config = model_configs.get(finetune_model_id, {})
    finetune_local_dir = finetune_config.get('local_dir')
    if finetune_local_dir and os.path.exists(finetune_local_dir) and any(os.path.isfile(os.path.join(finetune_local_dir, f)) for f in os.listdir(finetune_local_dir)):
        finetune_model_path = finetune_local_dir
    else:
        print(f"Local model directory for fine-tuned model '{finetune_model_id}' does not exist or is empty at '{finetune_local_dir}'. Falling back to model ID.")
        finetune_model_path = finetune_model_id

    # Determine LoRA model path
    lora_config = model_configs.get(lora_model_id, {})
    lora_local_dir = lora_config.get('local_dir')
    if lora_local_dir and os.path.exists(lora_local_dir) and any(os.path.isfile(os.path.join(lora_local_dir, f)) for f in os.listdir(lora_local_dir)):
        lora_model_path = lora_local_dir
    else:
        print(f"Local model directory for LoRA model '{lora_model_id}' does not exist or is empty at '{lora_local_dir}'. Falling back to model ID.")
        lora_model_path = lora_model_id

    # Determine base model path
    base_model_id = lora_config.get('base_model_id', 'stabilityai/stable-diffusion-2-1')
    base_model_config = model_configs.get(base_model_id, {})
    base_local_dir = base_model_config.get('local_dir')
    if base_local_dir and os.path.exists(base_local_dir) and any(os.path.isfile(os.path.join(base_local_dir, f)) for f in os.listdir(base_local_dir)):
        base_model_path = base_local_dir
    else:
        print(f"Local model directory for base model '{base_model_id}' does not exist or is empty at '{base_local_dir}'. Falling back to model ID.")
        base_model_path = base_model_id

    # Convert device string to torch.device
    device = torch.device(device)
    dtype = torch.float16 if device.type == "cuda" else torch.float32

    # Extract model IDs for dropdown choices based on type
    finetune_model_ids = [mid for mid, cfg in model_configs.items() if cfg.get('type') == 'full_finetuning']
    lora_model_ids = [mid for mid, cfg in model_configs.items() if cfg.get('type') == 'lora']
    base_model_ids = [model_configs[mid]['base_model_id'] for mid in model_configs if 'base_model_id' in model_configs[mid]]

    def update_model_path_visibility(use_lora):
        """
        Update visibility of model path dropdowns and LoRA sliders based on use_lora checkbox.
        """
        if use_lora:
            return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)
        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)

    def generate_image(prompt, height, width, num_inference_steps, guidance_scale, seed, random_seed, use_lora, finetune_model_id, lora_model_id, base_model_id, lora_rank, lora_scale):
        # Resolve model paths for generation
        model_configs = load_model_configs(config_path)
        finetune_config = model_configs.get(finetune_model_id, {})
        finetune_local_dir = finetune_config.get('local_dir')
        finetune_model_path = finetune_local_dir if finetune_local_dir and os.path.exists(finetune_local_dir) and any(os.path.isfile(os.path.join(finetune_local_dir, f)) for f in os.listdir(finetune_local_dir)) else finetune_model_id

        lora_config = model_configs.get(lora_model_id, {})
        lora_local_dir = lora_config.get('local_dir')
        lora_model_path = lora_local_dir if lora_local_dir and os.path.exists(lora_local_dir) and any(os.path.isfile(os.path.join(lora_local_dir, f)) for f in os.listdir(lora_local_dir)) else lora_model_id

        base_model_config = model_configs.get(base_model_id, {})
        base_local_dir = base_model_config.get('local_dir')
        base_model_path = base_local_dir if base_local_dir and os.path.exists(base_local_dir) and any(os.path.isfile(os.path.join(base_local_dir, f)) for f in os.listdir(base_local_dir)) else base_model_id

        if not prompt:
            return None, "Prompt cannot be empty."
        if height % 8 != 0 or width % 8 != 0:
            return None, "Height and width must be divisible by 8 (e.g., 256, 512, 1024)."
        if num_inference_steps < 1 or num_inference_steps > 100:
            return None, "Number of inference steps must be between 1 and 100."
        if guidance_scale < 1.0 or guidance_scale > 20.0:
            return None, "Guidance scale must be between 1.0 and 20.0."
        if seed < 0 or seed > 4294967295:
            return None, "Seed must be between 0 and 4294967295."
        if use_lora and (not lora_model_path or not os.path.exists(lora_model_path) and not lora_model_path.startswith("danhtran2mind/")):
            return None, f"LoRA model path {lora_model_path} does not exist or is invalid."
        if use_lora and (not base_model_path or not os.path.exists(base_model_path) and not base_model_path.startswith("stabilityai/")):
            return None, f"Base model path {base_model_path} does not exist or is invalid."
        if not use_lora and (not finetune_model_path or not os.path.exists(finetune_model_path) and not finetune_model_path.startswith("danhtran2mind/")):
            return None, f"Fine-tuned model path {finetune_model_path} does not exist or is invalid."
        if use_lora and (lora_rank < 1 or lora_rank > 128):
            return None, "LoRA rank must be between 1 and 128."
        if use_lora and (lora_scale < 0.0 or lora_scale > 2.0):
            return None, "LoRA scale must be between 0.0 and 2.0."

        batch_size = 1
        if random_seed:
            seed = torch.randint(0, 4294967295, (1,)).item()
        generator = torch.Generator(device=device).manual_seed(int(seed))

        # Load models based on use_lora
        if use_lora:
            try:
                pipe = StableDiffusionPipeline.from_pretrained(
                    base_model_path,
                    torch_dtype=dtype,
                    use_safetensors=True
                )
                pipe.load_lora_weights(lora_model_path, adapter_name="ghibli-lora", lora_scale=lora_scale)
                pipe = pipe.to(device)
                vae = pipe.vae
                tokenizer = pipe.tokenizer
                text_encoder = pipe.text_encoder
                unet = pipe.unet
                scheduler = PNDMScheduler.from_config(pipe.scheduler.config)
            except Exception as e:
                return None, f"Error loading LoRA model from {lora_model_path} or base model from {base_model_path}: {e}"
        else:
            try:
                vae = AutoencoderKL.from_pretrained(finetune_model_path, subfolder="vae", torch_dtype=dtype).to(device)
                tokenizer = CLIPTokenizer.from_pretrained(finetune_model_path, subfolder="tokenizer")
                text_encoder = CLIPTextModel.from_pretrained(finetune_model_path, subfolder="text_encoder", torch_dtype=dtype).to(device)
                unet = UNet2DConditionModel.from_pretrained(finetune_model_path, subfolder="unet", torch_dtype=dtype).to(device)
                scheduler = PNDMScheduler.from_pretrained(finetune_model_path, subfolder="scheduler")
            except Exception as e:
                return None, f"Error loading fine-tuned model from {finetune_model_path}: {e}"

        text_input = tokenizer(
            [prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
        )
        with torch.no_grad():
            text_embeddings = text_encoder(text_input.input_ids.to(device))[0].to(dtype=dtype)

        max_length = text_input.input_ids.shape[-1]
        uncond_input = tokenizer(
            [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
        )
        with torch.no_grad():
            uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0].to(dtype=dtype)

        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        latents = torch.randn(
            (batch_size, unet.config.in_channels, height // 8, width // 8),
            generator=generator,
            dtype=dtype,
            device=device
        )

        scheduler.set_timesteps(num_inference_steps)
        latents = latents * scheduler.init_noise_sigma

        for t in tqdm(scheduler.timesteps, desc="Generating image"):
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = scheduler.scale_model_input(latent_model_input, t)

            with torch.no_grad():
                if device.type == "cuda":
                    with torch.autocast(device_type="cuda", dtype=torch.float16):
                        noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
                else:
                    noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            latents = scheduler.step(noise_pred, t, latents).prev_sample

        with torch.no_grad():
            latents = latents / vae.config.scaling_factor
            image = vae.decode(latents).sample

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
        image = (image * 255).round().astype("uint8")
        pil_image = Image.fromarray(image[0])

        # Success message includes LoRA Path and LoRA Scale when use_lora is True
        if use_lora:
            return pil_image, f"Image generated successfully! Seed used: {seed}, LoRA Path: {lora_model_path}, LoRA Scale: {lora_scale}"
        return pil_image, f"Image generated successfully! Seed used: {seed}"

    def load_example_image(prompt, height, width, num_inference_steps, guidance_scale, seed, image_path, use_lora, finetune_model_id, lora_model_id, base_model_id, lora_rank, lora_scale):
        """
        Load the image for the selected example and update input fields.
        """
        if image_path and Path(image_path).exists():
            try:
                image = Image.open(image_path)
                return (
                    prompt, height, width, num_inference_steps, guidance_scale, seed, image,
                    use_lora, finetune_model_id, lora_model_id, base_model_id, lora_rank, lora_scale,
                    f"Loaded image: {image_path}"
                )
            except Exception as e:
                return (
                    prompt, height, width, num_inference_steps, guidance_scale, seed, None,
                    use_lora, finetune_model_id, lora_model_id, base_model_id, lora_rank, lora_scale,
                    f"Error loading image: {e}"
                )
        return (
            prompt, height, width, num_inference_steps, guidance_scale, seed, None,
            use_lora, finetune_model_id, lora_model_id, base_model_id, lora_rank, lora_scale,
            "No image available"
        )

    badges_text = r"""
    <div style="text-align: left; font-size: 14px; display: flex; flex-direction: column; gap: 10px;">
        <div style="display: flex; align-items: center; justify-content: left; gap: 8px;">
            You can explore GitHub repository:
            <a href="https://github.com/danhtran2mind/Ghibli-Stable-Diffusion-Synthesis">
                <img src="https://img.shields.io/badge/GitHub-danhtran2mind%2FGhibli--Stable--Diffusion--Synthesis-blue?style=flat&logo=github" alt="GitHub Repo">
            </a>. And you can explore HuggingFace Model Hub:
            <a href="https://huggingface.co/spaces/danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning">
                <img src="https://img.shields.io/badge/HuggingFace-danhtran2mind%2FGhibli--Stable--Diffusion--2.1--Base--finetuning-yellow?style=flat&logo=huggingface" alt="HuggingFace Space Demo">
            </a>
            and
            <a href="https://huggingface.co/spaces/danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA">
                <img src="https://img.shields.io/badge/HuggingFace-danhtran2mind%2FGhibli--Stable--Diffusion--2.1--LoRA-yellow?style=flat&logo=huggingface" alt="HuggingFace Space Demo">
            </a>
        </div>
    </div>
    """.strip()

    with gr.Blocks() as demo:
        gr.Markdown("# Ghibli-Style Image Generator")
        gr.Markdown(badges_text)
        gr.Markdown("Generate images in Ghibli style using a fine-tuned Stable Diffusion model or Stable Diffusion 2.1 with LoRA weights. Select an example below to load a pre-generated image or enter a prompt to generate a new one.")
        gr.Markdown("""**Note:** For CPU inference, execution time is long (e.g., for resolution 512 × 512 with 50 inference steps, time is approximately 1700 seconds).""")

        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", placeholder="e.g., 'a serene landscape in Ghibli style'")
                with gr.Row():
                    width = gr.Slider(32, 4096, 512, step=8, label="Generation Width")
                    height = gr.Slider(32, 4096, 512, step=8, label="Generation Height")
                with gr.Accordion("Advanced Options", open=False):
                    num_inference_steps = gr.Slider(1, 100, 50, step=1, label="Number of Inference Steps")
                    guidance_scale = gr.Slider(1.0, 20.0, 3.5, step=0.5, label="Guidance Scale")
                    seed = gr.Number(42, label="Seed (0 to 4294967295)")
                    random_seed = gr.Checkbox(label="Use Random Seed", value=False)
                    use_lora = gr.Checkbox(label="Use LoRA Weights", value=False)
                    finetune_model_path = gr.Dropdown(
                        label="Fine-tuned Model Path",
                        choices=finetune_model_ids,
                        value=finetune_model_id,
                        visible=not use_lora.value
                    )
                    lora_model_path = gr.Dropdown(
                        label="LoRA Model Path",
                        choices=lora_model_ids,
                        value=lora_model_id,
                        visible=use_lora.value
                    )
                    base_model_path = gr.Dropdown(
                        label="Base Model Path",
                        choices=base_model_ids,
                        value=base_model_id,
                        visible=use_lora.value
                    )

                    with gr.Group(visible=use_lora.value):
                        gr.Markdown("### LoRA Configuration")
                        lora_rank = gr.Slider(
                            1, 128, 64, step=1,
                            label="LoRA Rank (controls model complexity)",
                            visible=use_lora.value,
                            info="Adjusts the rank of LoRA weights, affecting model complexity and memory usage."
                        )
                        lora_scale = gr.Slider(
                            0.0, 2.0, 1.2, step=0.1,
                            label="LoRA Scale (controls weight influence)",
                            visible=use_lora.value,
                            info="Adjusts the influence of LoRA weights on the base model."
                        )
                generate_btn = gr.Button("Generate Image")

            with gr.Column():
                output_image = gr.Image(label="Generated Image")
                output_text = gr.Textbox(label="Status")

        examples = get_examples("apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning")
        gr.Examples(
            examples=examples,
            inputs=[prompt, height, width, num_inference_steps, guidance_scale, seed, output_image, use_lora, finetune_model_path, lora_model_path, base_model_path, lora_rank, lora_scale],
            outputs=[prompt, height, width, num_inference_steps, guidance_scale, seed, output_image, use_lora, finetune_model_path, lora_model_path, base_model_path, lora_rank, lora_scale, output_text],
            fn=load_example_image,
            cache_examples=False
        )

        use_lora.change(
            fn=update_model_path_visibility,
            inputs=use_lora,
            outputs=[lora_model_path, base_model_path, finetune_model_path, lora_rank, lora_scale]
        )

        generate_btn.click(
            fn=generate_image,
            inputs=[prompt, height, width, num_inference_steps, guidance_scale, seed, random_seed, use_lora, finetune_model_path, lora_model_path, base_model_path, lora_rank, lora_scale],
            outputs=[output_image, output_text]
        )

    return demo

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Ghibli-Style Image Generator using a fine-tuned Stable Diffusion model or Stable Diffusion 2.1 with LoRA weights.")
    parser.add_argument(
        "--config_path",
        type=str,
        default="configs/model_ckpts.yaml",
        help="Path to the model configuration YAML file."
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to run the model on (e.g., 'cuda', 'cpu')."
    )
    parser.add_argument(
        "--port",
        type=int,
        default=7860,
        help="Port to run the Gradio app on."
    )
    parser.add_argument(
        "--share",
        action="store_true",
        default=False,
        help="Set to True for public sharing (Hugging Face Spaces)."
    )

    args = parser.parse_args()

    demo = create_demo(args.config_path, args.device)
    demo.launch(server_port=args.port, share=args.share)
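For the LoRA branch used in the file above, a hedged, standalone sketch follows: load the base Stable Diffusion 2.1 pipeline, attach the LoRA weights, and apply a runtime scale through cross_attention_kwargs. The LoRA repository id shown is an assumption for illustration only.

# Standalone sketch of the LoRA inference path (the LoRA repo id is an assumption).
import torch
from diffusers import StableDiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=dtype
).to(device)
pipe.load_lora_weights("danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA")  # assumed repo id
image = pipe(
    "a serene landscape in Ghibli style",
    num_inference_steps=50,
    guidance_scale=3.5,
    cross_attention_kwargs={"scale": 1.2},  # runtime LoRA scale, analogous to lora_scale above
).images[0]
image.save("ghibli_lora_sample.png")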
apps/old4-gradio_app.py
ADDED
@@ -0,0 +1,548 @@
import argparse
import json
from pathlib import Path
import os
import gradio as gr
import torch
from PIL import Image
import numpy as np
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, StableDiffusionPipeline
from tqdm import tqdm
import yaml

def load_model_configs(config_path: str = "configs/model_ckpts.yaml") -> dict:
    """
    Load model configurations from a YAML file.
    Returns a dictionary with model IDs and their details.
    """
    try:
        with open(config_path, 'r') as f:
            configs = yaml.safe_load(f)
        return {cfg['model_id']: cfg for cfg in configs}
    except (IOError, yaml.YAMLError) as e:
        raise ValueError(f"Error loading {config_path}: {e}")

def get_examples(examples_dir: str = "apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning") -> list:

    if not os.path.exists(examples_dir) or not os.path.isdir(examples_dir):
        raise ValueError(f"Directory {examples_dir} does not exist or is not a directory")

    all_examples_dir = [os.path.join(examples_dir, d) for d in os.listdir(examples_dir)
                        if os.path.isdir(os.path.join(examples_dir, d))]

    ans = []
    for example_dir in all_examples_dir:
        config_path = os.path.join(example_dir, "config.json")
        image_path = os.path.join(example_dir, "result.png")

        if not os.path.isfile(config_path):
            print(f"Warning: config.json not found in {example_dir}")
            continue
        if not os.path.isfile(image_path):
            print(f"Warning: result.png not found in {example_dir}")
            continue

        try:
            with open(config_path, 'r') as f:
                example_dict = json.load(f)
        except (json.JSONDecodeError, IOError) as e:
            print(f"Error reading or parsing {config_path}: {e}")
            continue

        # Required keys for all configs
        required_keys = ["prompt", "height", "width", "num_inference_steps", "guidance_scale", "seed", "image"]
        if not all(key in example_dict for key in required_keys):
            print(f"Warning: Missing required keys in {config_path}")
            continue

        if example_dict["image"] != "result.png":
            print(f"Warning: Image key in {config_path} does not match 'result.png'")
            continue

        try:
            use_lora = example_dict.get("use_lora", False)
            example_list = [
                example_dict["prompt"],
                example_dict["height"],
                example_dict["width"],
                example_dict["num_inference_steps"],
                example_dict["guidance_scale"],
                example_dict["seed"],
                image_path,
                use_lora
            ]

            if use_lora:
                # Additional required keys for LoRA config
                lora_required_keys = ["lora_model_id", "base_model_id", "lora_rank", "lora_scale"]
                if not all(key in example_dict for key in lora_required_keys):
                    print(f"Warning: Missing required LoRA keys in {config_path}")
                    continue

                example_list.extend([
                    None,  # finetune_model_id (not used for LoRA)
                    example_dict["lora_model_id"],
                    example_dict["base_model_id"],
                    example_dict["lora_rank"],
                    example_dict["lora_scale"]
                ])
            else:
                # Additional required key for non-LoRA config
                if "finetune_model_id" not in example_dict:
                    print(f"Warning: Missing finetune_model_id in {config_path}")
                    continue

                example_list.extend([
                    example_dict["finetune_model_id"],
                    None,  # lora_model_id
                    None,  # base_model_id
                    None,  # lora_rank
                    None   # lora_scale
                ])

            ans.append(example_list)
        except KeyError as e:
            print(f"Error processing {config_path}: Missing key {e}")
            continue

    if not ans:
        # Default example for non-LoRA
        ans = [
            ["a serene landscape in Ghibli style", 512, 512, 50, 3.5, 42, None, False,
             "stabilityai/stable-diffusion-2-1-base",
             None, None, None, None]
        ]
        # Default example for LoRA
        ans.append(
            ["a serene landscape in Ghibli style", 512, 512, 50, 3.5, 42, None, True,
             None,
             "stabilityai/stable-diffusion-2-1",
             "stabilityai/stable-diffusion-2-1",
             64, 1.2]
        )

    return ans

def create_demo(
    config_path: str = "configs/model_ckpts.yaml",
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    # Load model configurations
    model_configs = load_model_configs(config_path)

    # Load model IDs from YAML
    finetune_model_id = next((mid for mid, cfg in model_configs.items() if cfg.get('type') == 'full_finetuning'), None)
    lora_model_id = next((mid for mid, cfg in model_configs.items() if cfg.get('type') == 'lora'), None)

    if not finetune_model_id or not lora_model_id:
        raise ValueError("Could not find full_finetuning or lora model IDs in the configuration file.")

    # Determine finetune model path
    finetune_config = model_configs.get(finetune_model_id, {})
    finetune_local_dir = finetune_config.get('local_dir')
    if finetune_local_dir and os.path.exists(finetune_local_dir) and any(os.path.isfile(os.path.join(finetune_local_dir, f)) for f in os.listdir(finetune_local_dir)):
        finetune_model_path = finetune_local_dir
    else:
        print(f"Local model directory for fine-tuned model '{finetune_model_id}' does not exist or is empty at '{finetune_local_dir}'. Falling back to model ID.")
        finetune_model_path = finetune_model_id

    # Determine LoRA model path
    lora_config = model_configs.get(lora_model_id, {})
    lora_local_dir = lora_config.get('local_dir')
    if lora_local_dir and os.path.exists(lora_local_dir) and any(os.path.isfile(os.path.join(lora_local_dir, f)) for f in os.listdir(lora_local_dir)):
        lora_model_path = lora_local_dir
    else:
        print(f"Local model directory for LoRA model '{lora_model_id}' does not exist or is empty at '{lora_local_dir}'. Falling back to model ID.")
        lora_model_path = lora_model_id

    # Determine base model path
    base_model_id = lora_config.get('base_model_id', 'stabilityai/stable-diffusion-2-1')
    base_model_config = model_configs.get(base_model_id, {})
    base_local_dir = base_model_config.get('local_dir')
    if base_local_dir and os.path.exists(base_local_dir) and any(os.path.isfile(os.path.join(base_local_dir, f)) for f in os.listdir(base_local_dir)):
        base_model_path = base_local_dir
    else:
        print(f"Local model directory for base model '{base_model_id}' does not exist or is empty at '{base_local_dir}'. Falling back to model ID.")
        base_model_path = base_model_id

    # Convert device string to torch.device
    device = torch.device(device)
    dtype = torch.float16 if device.type == "cuda" else torch.float32

    # Extract model IDs for dropdown choices based on type
    finetune_model_ids = [mid for mid, cfg in model_configs.items() if cfg.get('type') == 'full_finetuning']
    lora_model_ids = [mid for mid, cfg in model_configs.items() if cfg.get('type') == 'lora']
    base_model_ids = [model_configs[mid]['base_model_id'] for mid in model_configs if 'base_model_id' in model_configs[mid]]

    def update_model_path_visibility(use_lora):
        """
        Update visibility of model path dropdowns and LoRA sliders based on use_lora checkbox.
        """
        if use_lora:
            return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)
        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)

    def generate_image(prompt, height, width, num_inference_steps, guidance_scale, seed, random_seed, use_lora, finetune_model_id, lora_model_id, base_model_id, lora_rank, lora_scale):
        # Resolve model paths for generation
        model_configs = load_model_configs(config_path)
        finetune_config = model_configs.get(finetune_model_id, {})
        finetune_local_dir = finetune_config.get('local_dir')
        finetune_model_path = finetune_local_dir if finetune_local_dir and os.path.exists(finetune_local_dir) and any(os.path.isfile(os.path.join(finetune_local_dir, f)) for f in os.listdir(finetune_local_dir)) else finetune_model_id

        lora_config = model_configs.get(lora_model_id, {})
        lora_local_dir = lora_config.get('local_dir')
        lora_model_path = lora_local_dir if lora_local_dir and os.path.exists(lora_local_dir) and any(os.path.isfile(os.path.join(lora_local_dir, f)) for f in os.listdir(lora_local_dir)) else lora_model_id

        base_model_config = model_configs.get(base_model_id, {})
        base_local_dir = base_model_config.get('local_dir')
        base_model_path = base_local_dir if base_local_dir and os.path.exists(base_local_dir) and any(os.path.isfile(os.path.join(base_local_dir, f)) for f in os.listdir(base_local_dir)) else base_model_id

        if not prompt:
            return None, "Prompt cannot be empty."
        if height % 8 != 0 or width % 8 != 0:
            return None, "Height and width must be divisible by 8 (e.g., 256, 512, 1024)."
        if num_inference_steps < 1 or num_inference_steps > 100:
            return None, "Number of inference steps must be between 1 and 100."
        if guidance_scale < 1.0 or guidance_scale > 20.0:
            return None, "Guidance scale must be between 1.0 and 20.0."
        if seed < 0 or seed > 4294967295:
            return None, "Seed must be between 0 and 4294967295."
        if use_lora and (not lora_model_path or not os.path.exists(lora_model_path) and not lora_model_path.startswith("danhtran2mind/")):
            return None, f"LoRA model path {lora_model_path} does not exist or is invalid."
        if use_lora and (not base_model_path or not os.path.exists(base_model_path) and not base_model_path.startswith("stabilityai/")):
            return None, f"Base model path {base_model_path} does not exist or is invalid."
        if not use_lora and (not finetune_model_path or not os.path.exists(finetune_model_path) and not finetune_model_path.startswith("danhtran2mind/")):
            return None, f"Fine-tuned model path {finetune_model_path} does not exist or is invalid."
        if use_lora and (lora_rank < 1 or lora_rank > 128):
            return None, "LoRA rank must be between 1 and 128."
        if use_lora and (lora_scale < 0.0 or lora_scale > 2.0):
            return None, "LoRA scale must be between 0.0 and 2.0."

        batch_size = 1
        if random_seed:
            seed = torch.randint(0, 4294967295, (1,)).item()
        generator = torch.Generator(device=device).manual_seed(int(seed))

        # Load models based on use_lora
        if use_lora:
            try:
                pipe = StableDiffusionPipeline.from_pretrained(
                    base_model_path,
                    torch_dtype=dtype,
                    use_safetensors=True
                )
                pipe.load_lora_weights(lora_model_path, adapter_name="ghibli-lora", lora_scale=lora_scale)
                pipe = pipe.to(device)
                vae = pipe.vae
                tokenizer = pipe.tokenizer
                text_encoder = pipe.text_encoder
                unet = pipe.unet
                scheduler = PNDMScheduler.from_config(pipe.scheduler.config)
            except Exception as e:
                return None, f"Error loading LoRA model from {lora_model_path} or base model from {base_model_path}: {e}"
        else:
            try:
                vae = AutoencoderKL.from_pretrained(finetune_model_path, subfolder="vae", torch_dtype=dtype).to(device)
                tokenizer = CLIPTokenizer.from_pretrained(finetune_model_path, subfolder="tokenizer")
                text_encoder = CLIPTextModel.from_pretrained(finetune_model_path, subfolder="text_encoder", torch_dtype=dtype).to(device)
                unet = UNet2DConditionModel.from_pretrained(finetune_model_path, subfolder="unet", torch_dtype=dtype).to(device)
                scheduler = PNDMScheduler.from_pretrained(finetune_model_path, subfolder="scheduler")
            except Exception as e:
                return None, f"Error loading fine-tuned model from {finetune_model_path}: {e}"

        text_input = tokenizer(
            [prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
        )
        with torch.no_grad():
            text_embeddings = text_encoder(text_input.input_ids.to(device))[0].to(dtype=dtype)

        max_length = text_input.input_ids.shape[-1]
        uncond_input = tokenizer(
            [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
        )
        with torch.no_grad():
            uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0].to(dtype=dtype)

        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        latents = torch.randn(
            (batch_size, unet.config.in_channels, height // 8, width // 8),
            generator=generator,
            dtype=dtype,
            device=device
        )

        scheduler.set_timesteps(num_inference_steps)
        latents = latents * scheduler.init_noise_sigma

        for t in tqdm(scheduler.timesteps, desc="Generating image"):
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = scheduler.scale_model_input(latent_model_input, t)

            with torch.no_grad():
                if device.type == "cuda":
                    with torch.autocast(device_type="cuda", dtype=torch.float16):
                        noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
                else:
                    noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            latents = scheduler.step(noise_pred, t, latents).prev_sample

        with torch.no_grad():
            latents = latents / vae.config.scaling_factor
            image = vae.decode(latents).sample

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
        image = (image * 255).round().astype("uint8")
        pil_image = Image.fromarray(image[0])

        # Success message includes LoRA Path and LoRA Scale when use_lora is True
        if use_lora:
            return pil_image, f"Image generated successfully! Seed used: {seed}, LoRA Path: {lora_model_path}, LoRA Scale: {lora_scale}"
        return pil_image, f"Image generated successfully! Seed used: {seed}"

    def load_example_image(prompt, height, width, num_inference_steps, guidance_scale,
                           seed, image_path, use_lora, finetune_model_id, lora_model_id,
                           base_model_id, lora_rank, lora_scale):
        """
        Load the image for the selected example and update input fields.
        """
        if image_path and Path(image_path).exists():
            try:
                image = Image.open(image_path)
                return (
                    prompt, height, width, num_inference_steps, guidance_scale, seed, image,
                    use_lora, finetune_model_id, lora_model_id, base_model_id, lora_rank, lora_scale,
                    f"Loaded image: {image_path}"
                )
            except Exception as e:
                return (
                    prompt, height, width, num_inference_steps, guidance_scale, seed, None,
                    use_lora, finetune_model_id, lora_model_id, base_model_id, lora_rank, lora_scale,
                    f"Error loading image: {e}"
                )
        return (
            prompt, height, width, num_inference_steps, guidance_scale, seed, None,
            use_lora, finetune_model_id, lora_model_id, base_model_id, lora_rank, lora_scale,
            "No image available"
        )

    badges_text = r"""
    <div style="text-align: left; font-size: 14px; display: flex; flex-direction: column; gap: 10px;">
        <div style="display: flex; align-items: center; justify-content: left; gap: 8px;">
            You can explore GitHub repository:
            <a href="https://github.com/danhtran2mind/Ghibli-Stable-Diffusion-Synthesis">
                <img src="https://img.shields.io/badge/GitHub-danhtran2mind%2FGhibli--Stable--Diffusion--Synthesis-blue?style=flat&logo=github" alt="GitHub Repo">
            </a>. And you can explore HuggingFace Model Hub:
            <a href="https://huggingface.co/spaces/danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning">
                <img src="https://img.shields.io/badge/HuggingFace-danhtran2mind%2FGhibli--Stable--Diffusion--2.1--Base--finetuning-yellow?style=flat&logo=huggingface" alt="HuggingFace Space Demo">
            </a>
            and
            <a href="https://huggingface.co/spaces/danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA">
                <img src="https://img.shields.io/badge/HuggingFace-danhtran2mind%2FGhibli--Stable--Diffusion--2.1--LoRA-yellow?style=flat&logo=huggingface" alt="HuggingFace Space Demo">
            </a>
        </div>
    </div>
    """.strip()

    with gr.Blocks() as demo:
        # Main Layout: Split into Input and Output Columns
        with gr.Row():
            # Input Column
            with gr.Column(scale=1):
                gr.Markdown("## Image Generation Settings")

                # Prompt Input
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="e.g., 'a serene landscape in Ghibli style'",
                    lines=2
                )

                # Image Dimensions
                with gr.Group():
                    gr.Markdown("### Image Dimensions")
                    with gr.Row():
                        width = gr.Slider(
                            minimum=32,
                            maximum=4096,
                            value=512,
                            step=8,
                            label="Width"
                        )
                        height = gr.Slider(
                            minimum=32,
                            maximum=4096,
                            value=512,
                            step=8,
                            label="Height"
                        )

                # Advanced Settings Accordion
                with gr.Accordion("Advanced Settings", open=False):
                    num_inference_steps = gr.Slider(
                        minimum=1,
                        maximum=100,
                        value=50,
                        step=1,
                        label="Inference Steps",
                        info="Higher steps improve quality but increase generation time."
                    )
                    guidance_scale = gr.Slider(
                        minimum=1.0,
                        maximum=20.0,
                        value=3.5,
                        step=0.5,
                        label="Guidance Scale",
                        info="Controls how closely the image follows the prompt."
                    )
                    lora_rank = gr.Slider(
                        minimum=1,
                        maximum=128,
                        value=64,
                        step=1,
                        visible=False,  # Initially hidden
                        label="LoRA Rank",
                        info="Controls model complexity and memory usage."
                    )
                    lora_scale = gr.Slider(
                        minimum=0.0,
                        maximum=2.0,
                        value=1.2,
                        step=0.1,
                        visible=False,  # Initially hidden
                        label="LoRA Scale",
                        info="Adjusts the influence of LoRA weights."
                    )
                    random_seed = gr.Checkbox(
                        label="Use Random Seed",
                        value=False
                    )
                    seed = gr.Slider(
                        minimum=0,
                        maximum=4294967295,
                        value=42,
                        step=1,
                        label="Seed (0–4294967295)",
                        info="Set a specific seed for reproducible results."
                    )

                # Model Selection
                with gr.Group():
                    gr.Markdown("### Model Configuration")
                    use_lora = gr.Checkbox(
                        label="Use LoRA Weights",
                        value=False,
                        info="Enable to use LoRA weights with a base model."
                    )

                    # Model Path Dropdowns
                    finetune_model_path = gr.Dropdown(
                        label="Fine-tuned Model",
                        choices=finetune_model_ids,
                        value=finetune_model_id,
                        visible=not use_lora.value
                    )
                    lora_model_path = gr.Dropdown(
                        label="LoRA Model",
                        choices=lora_model_ids,
                        value=lora_model_id,
                        visible=use_lora.value
                    )
                    base_model_path = gr.Dropdown(
                        label="Base Model",
                        choices=base_model_ids,
                        value=base_model_id,
                        visible=use_lora.value
                    )

                # Generate Button
                generate_btn = gr.Button("Generate Image", variant="primary")

            # Output Column
            with gr.Column(scale=1):
                gr.Markdown("## Generated Result")
                output_image = gr.Image(
                    label="Generated Image",
                    interactive=False,
                    height=512
                )
                output_text = gr.Textbox(
                    label="Generation Status",
                    interactive=False,
                    lines=3
                )

        # Examples Section
        gr.Markdown("## Try an Example")
        examples = get_examples("apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning")
        gr.Examples(
            examples=examples,
            inputs=[
                prompt, height, width, num_inference_steps, guidance_scale, seed,
|
| 487 |
+
use_lora, finetune_model_path, lora_model_path, base_model_path,
|
| 488 |
+
lora_rank, lora_scale
|
| 489 |
+
],
|
| 490 |
+
outputs=[
|
| 491 |
+
prompt, height, width, num_inference_steps, guidance_scale, seed,
|
| 492 |
+
output_image, use_lora, finetune_model_path, lora_model_path,
|
| 493 |
+
base_model_path, lora_rank, lora_scale, output_text
|
| 494 |
+
],
|
| 495 |
+
fn=load_example_image,
|
| 496 |
+
cache_examples=False
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
+
# Event Handlers
|
| 500 |
+
use_lora.change(
|
| 501 |
+
fn=update_model_path_visibility,
|
| 502 |
+
inputs=use_lora,
|
| 503 |
+
outputs=[lora_model_path, base_model_path, finetune_model_path, lora_rank, lora_scale]
|
| 504 |
+
)
|
| 505 |
+
|
| 506 |
+
generate_btn.click(
|
| 507 |
+
fn=generate_image,
|
| 508 |
+
inputs=[
|
| 509 |
+
prompt, height, width, num_inference_steps, guidance_scale, seed,
|
| 510 |
+
random_seed, use_lora, finetune_model_path, lora_model_path,
|
| 511 |
+
base_model_path, lora_rank, lora_scale
|
| 512 |
+
],
|
| 513 |
+
outputs=[output_image, output_text]
|
| 514 |
+
)
|
| 515 |
+
|
| 516 |
+
return demo
|
| 517 |
+
|
| 518 |
+
if __name__ == "__main__":
|
| 519 |
+
parser = argparse.ArgumentParser(description="Ghibli-Style Image Generator using a fine-tuned Stable Diffusion model or Stable Diffusion 2.1 with LoRA weights.")
|
| 520 |
+
parser.add_argument(
|
| 521 |
+
"--config_path",
|
| 522 |
+
type=str,
|
| 523 |
+
default="configs/model_ckpts.yaml",
|
| 524 |
+
help="Path to the model configuration YAML file."
|
| 525 |
+
)
|
| 526 |
+
parser.add_argument(
|
| 527 |
+
"--device",
|
| 528 |
+
type=str,
|
| 529 |
+
default="cuda" if torch.cuda.is_available() else "cpu",
|
| 530 |
+
help="Device to run the model on (e.g., 'cuda', 'cpu')."
|
| 531 |
+
)
|
| 532 |
+
parser.add_argument(
|
| 533 |
+
"--port",
|
| 534 |
+
type=int,
|
| 535 |
+
default=7860,
|
| 536 |
+
help="Port to run the Gradio app on."
|
| 537 |
+
)
|
| 538 |
+
parser.add_argument(
|
| 539 |
+
"--share",
|
| 540 |
+
action="store_true",
|
| 541 |
+
default=False,
|
| 542 |
+
help="Set to True for public sharing (Hugging Face Spaces)."
|
| 543 |
+
)
|
| 544 |
+
|
| 545 |
+
args = parser.parse_args()
|
| 546 |
+
|
| 547 |
+
demo = create_demo(args.config_path, args.device)
|
| 548 |
+
demo.launch(server_port=args.port, share=args.share)
|
apps/old5-gradio_app.py
ADDED
@@ -0,0 +1,258 @@
import argparse
import json
from typing import Union, List
from pathlib import Path
import os
import gradio as gr
import torch
from PIL import Image
import numpy as np
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, StableDiffusionPipeline
from tqdm import tqdm
import yaml

def load_model_configs(config_path: str = "configs/model_ckpts.yaml") -> dict:
    with open(config_path, 'r') as f:
        return {cfg['model_id']: cfg for cfg in yaml.safe_load(f)}

def get_examples(examples_dir: Union[str, List[str]] = None, use_lora: bool = None) -> List:
    directories = [examples_dir] if isinstance(examples_dir, str) else examples_dir or []
    valid_dirs = [d for d in directories if os.path.isdir(d)]
    if not valid_dirs:
        return get_provided_examples(use_lora)

    examples = []
    for dir_path in valid_dirs:
        for subdir in sorted(os.path.join(dir_path, d) for d in os.listdir(dir_path) if os.path.isdir(os.path.join(dir_path, d))):
            config_path = os.path.join(subdir, "config.json")
            image_path = os.path.join(subdir, "result.png")
            if not (os.path.isfile(config_path) and os.path.isfile(image_path)):
                continue

            with open(config_path, 'r') as f:
                config = json.load(f)

            required_keys = ["prompt", "height", "width", "num_inference_steps", "guidance_scale", "seed", "image"]
            if config.get("use_lora", False):
                required_keys.extend(["lora_model_id", "base_model_id", "lora_rank", "lora_scale"])
            else:
                required_keys.append("finetune_model_id")

            if set(required_keys) - set(config.keys()) or config["image"] != "result.png":
                continue

            try:
                image = Image.open(image_path)
            except Exception:
                continue

            if use_lora is not None and config.get("use_lora", False) != use_lora:
                continue

            example = [config["prompt"], config["height"], config["width"], config["num_inference_steps"],
                       config["guidance_scale"], config["seed"], image]
            example.extend([config["lora_model_id"], config["base_model_id"], config["lora_rank"], config["lora_scale"]]
                           if config.get("use_lora", False) else [config["finetune_model_id"]])
            examples.append(example)

    return examples or get_provided_examples(use_lora)

def get_provided_examples(use_lora: bool = False) -> list:
    example_path = f"apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-{'LoRA' if use_lora else 'Base-finetuning'}/1/result.png"
    image = Image.open(example_path) if os.path.exists(example_path) else None
    return [[
        "a cat is laying on a sofa in Ghibli style" if use_lora else "a serene landscape in Ghibli style",
        512, 768 if use_lora else 512, 100 if use_lora else 50, 10.0 if use_lora else 3.5, 789 if use_lora else 42,
        image, "danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA" if use_lora else "danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning",
        "stabilityai/stable-diffusion-2-1" if use_lora else None, 64 if use_lora else None, 0.9 if use_lora else None
    ]]

def create_demo(config_path: str = "configs/model_ckpts.yaml", device: str = "cuda" if torch.cuda.is_available() else "cpu"):
    model_configs = load_model_configs(config_path)
    finetune_model_id = next((mid for mid, cfg in model_configs.items() if cfg.get('type') == 'full_finetuning'), None)
    lora_model_id = next((mid for mid, cfg in model_configs.items() if cfg.get('type') == 'lora'), None)

    if not finetune_model_id or not lora_model_id:
        raise ValueError("Missing model IDs in config.")

    finetune_model_path = model_configs[finetune_model_id].get('local_dir', finetune_model_id)
    lora_model_path = model_configs[lora_model_id].get('local_dir', lora_model_id)
    base_model_id = model_configs[lora_model_id].get('base_model_id', 'stabilityai/stable-diffusion-2-1')
    base_model_path = model_configs.get(base_model_id, {}).get('local_dir', base_model_id)

    device = torch.device(device)
    dtype = torch.float16 if device.type == "cuda" else torch.float32

    def generate_image(prompt, height, width, num_inference_steps, guidance_scale, seed, random_seed, use_lora,
                       finetune_model_id, lora_model_id, base_model_id, lora_rank, lora_scale):
        if not prompt or height % 8 != 0 or width % 8 != 0 or num_inference_steps not in range(1, 101) or \
           guidance_scale < 1.0 or guidance_scale > 20.0 or seed < 0 or seed > 4294967295 or \
           (use_lora and (lora_rank < 1 or lora_rank > 128 or lora_scale < 0.0 or lora_scale > 2.0)):
            return None, "Invalid input parameters."

        model_configs = load_model_configs(config_path)
        finetune_model_path = model_configs.get(finetune_model_id, {}).get('local_dir', finetune_model_id)
        lora_model_path = model_configs.get(lora_model_id, {}).get('local_dir', lora_model_id)
        base_model_path = model_configs.get(base_model_id, {}).get('local_dir', base_model_id)

        generator = torch.Generator(device=device).manual_seed(torch.randint(0, 4294967295, (1,)).item() if random_seed else int(seed))

        try:
            if use_lora:
                pipe = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=dtype, use_safetensors=True)
                pipe.load_lora_weights(lora_model_path, adapter_name="ghibli-lora", lora_scale=lora_scale)
                pipe = pipe.to(device)
                vae, tokenizer, text_encoder, unet, scheduler = pipe.vae, pipe.tokenizer, pipe.text_encoder, pipe.unet, PNDMScheduler.from_config(pipe.scheduler.config)
            else:
                vae = AutoencoderKL.from_pretrained(finetune_model_path, subfolder="vae", torch_dtype=dtype).to(device)
                tokenizer = CLIPTokenizer.from_pretrained(finetune_model_path, subfolder="tokenizer")
                text_encoder = CLIPTextModel.from_pretrained(finetune_model_path, subfolder="text_encoder", torch_dtype=dtype).to(device)
                unet = UNet2DConditionModel.from_pretrained(finetune_model_path, subfolder="unet", torch_dtype=dtype).to(device)
                scheduler = PNDMScheduler.from_pretrained(finetune_model_path, subfolder="scheduler")

            text_input = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
            text_embeddings = text_encoder(text_input.input_ids.to(device))[0].to(dtype=dtype)

            uncond_input = tokenizer([""] * 1, padding="max_length", max_length=text_input.input_ids.shape[-1], return_tensors="pt")
            uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0].to(dtype=dtype)
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

            latents = torch.randn((1, unet.config.in_channels, height // 8, width // 8), generator=generator, dtype=dtype, device=device)
            scheduler.set_timesteps(num_inference_steps)
            latents = latents * scheduler.init_noise_sigma

            for t in tqdm(scheduler.timesteps, desc="Generating image"):
                latent_model_input = torch.cat([latents] * 2)
                latent_model_input = scheduler.scale_model_input(latent_model_input, t)
                noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                latents = scheduler.step(noise_pred, t, latents).prev_sample

            image = vae.decode(latents / vae.config.scaling_factor).sample
            image = (image / 2 + 0.5).clamp(0, 1).detach().cpu().permute(0, 2, 3, 1).numpy()
            pil_image = Image.fromarray((image[0] * 255).round().astype("uint8"))

            if use_lora:
                del pipe
            else:
                del vae, tokenizer, text_encoder, unet, scheduler
            torch.cuda.empty_cache()

            return pil_image, f"Generated image successfully! Seed used: {seed}"
        except Exception as e:
            return None, f"Failed to generate image: {e}"

    def load_example_image_full_finetuning(prompt, height, width, num_inference_steps, guidance_scale, seed, image, finetune_model_id):
        return prompt, height, width, num_inference_steps, guidance_scale, seed, image, finetune_model_id, "Loaded example successfully"

    def load_example_image_lora(prompt, height, width, num_inference_steps, guidance_scale, seed, image, lora_model_id, base_model_id, lora_rank, lora_scale):
        return prompt, height, width, num_inference_steps, guidance_scale, seed, image, lora_model_id, base_model_id or "stabilityai/stable-diffusion-2-1", lora_rank or 64, lora_scale or 1.2, "Loaded example successfully"

    badges_text = """
    <div style="text-align: left; font-size: 14px; display: flex; flex-direction: column; gap: 10px;">
        <div style="display: flex; align-items: center; justify-content: left; gap: 8px;">
            GitHub: <a href="https://github.com/danhtran2mind/Ghibli-Stable-Diffusion-Synthesis">
                <img src="https://img.shields.io/badge/GitHub-danhtran2mind%2FGhibli--Stable--Diffusion--Synthesis-blue?style=flat&logo=github" alt="GitHub Repo">
            </a> HuggingFace:
            <a href="https://huggingface.co/spaces/danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning">
                <img src="https://img.shields.io/badge/HuggingFace-danhtran2mind%2FGhibli--Stable--Diffusion--2.1--Base--finetuning-yellow?style=flat&logo=huggingface" alt="HuggingFace Space Demo">
            </a>
            <a href="https://huggingface.co/spaces/danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA">
                <img src="https://img.shields.io/badge/HuggingFace-danhtran2mind%2FGhibli--Stable--Diffusion--2.1--LoRA-yellow?style=flat&logo=huggingface" alt="HuggingFace Space Demo">
            </a>
        </div>
    </div>
    """

    custom_css = open("apps/gradio_app/static/styles.css", "r").read() if os.path.exists("apps/gradio_app/static/styles.css") else ""

    examples_full_finetuning = get_examples("apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning", use_lora=False)
    examples_lora = get_examples("apps/gradio_app/assets/examples/Ghibli-Stable-Diffusion-2.1-LoRA", use_lora=True)

    with gr.Blocks(css=custom_css, theme="ocean") as demo:
        gr.Markdown("## Ghibli-Style Image Generator")
        with gr.Tabs():
            with gr.Tab(label="Full Finetuning"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### Image Generation Settings")
                        prompt_ft = gr.Textbox(label="Prompt", placeholder="e.g., 'a serene landscape in Ghibli style'", lines=2)
                        with gr.Group():
                            gr.Markdown("#### Image Dimensions")
                            with gr.Row():
                                width_ft = gr.Slider(32, 4096, 512, step=8, label="Width")
                                height_ft = gr.Slider(32, 4096, 512, step=8, label="Height")
                        with gr.Accordion("Advanced Settings", open=False):
                            num_inference_steps_ft = gr.Slider(1, 100, 50, step=1, label="Inference Steps")
                            guidance_scale_ft = gr.Slider(1.0, 20.0, 3.5, step=0.5, label="Guidance Scale")
                            random_seed_ft = gr.Checkbox(label="Use Random Seed")
                            seed_ft = gr.Slider(0, 4294967295, 42, step=1, label="Seed")
                        gr.Markdown("#### Model Configuration")
                        finetune_model_path_ft = gr.Dropdown(label="Fine-tuned Model", choices=[mid for mid, cfg in model_configs.items() if cfg.get('type') == 'full_finetuning'], value=finetune_model_id)
                    with gr.Column(scale=1):
                        gr.Markdown("### Generated Result")
                        output_image_ft = gr.Image(label="Generated Image", interactive=False, height=512)
                        output_text_ft = gr.Textbox(label="Status", interactive=False, lines=3)
                        generate_btn_ft = gr.Button("Generate Image", variant="primary")
                        stop_btn_ft = gr.Button("Stop Generation")
                gr.Markdown("### Examples for Full Finetuning")
                gr.Examples(examples=examples_full_finetuning, inputs=[prompt_ft, height_ft, width_ft, num_inference_steps_ft, guidance_scale_ft, seed_ft, output_image_ft, finetune_model_path_ft],
                            outputs=[prompt_ft, height_ft, width_ft, num_inference_steps_ft, guidance_scale_ft, seed_ft, output_image_ft, finetune_model_path_ft, output_text_ft],
                            fn=load_example_image_full_finetuning, cache_examples=False, examples_per_page=4)

            with gr.Tab(label="LoRA"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### Image Generation Settings")
                        prompt_lora = gr.Textbox(label="Prompt", placeholder="e.g., 'a serene landscape in Ghibli style'", lines=2)
                        with gr.Group():
                            gr.Markdown("#### Image Dimensions")
                            with gr.Row():
                                width_lora = gr.Slider(32, 4096, 512, step=8, label="Width")
                                height_lora = gr.Slider(32, 4096, 512, step=8, label="Height")
                        with gr.Accordion("Advanced Settings", open=False):
                            num_inference_steps_lora = gr.Slider(1, 100, 50, step=1, label="Inference Steps")
                            guidance_scale_lora = gr.Slider(1.0, 20.0, 3.5, step=0.5, label="Guidance Scale")
                            lora_rank_lora = gr.Slider(1, 128, 64, step=1, label="LoRA Rank")
                            lora_scale_lora = gr.Slider(0.0, 2.0, 1.2, step=0.1, label="LoRA Scale")
                            random_seed_lora = gr.Checkbox(label="Use Random Seed")
                            seed_lora = gr.Slider(0, 4294967295, 42, step=1, label="Seed")
                        gr.Markdown("#### Model Configuration")
                        lora_model_path_lora = gr.Dropdown(label="LoRA Model", choices=[mid for mid, cfg in model_configs.items() if cfg.get('type') == 'lora'], value=lora_model_id)
                        base_model_path_lora = gr.Dropdown(label="Base Model", choices=[model_configs[mid].get('base_model_id') for mid in model_configs if model_configs[mid].get('base_model_id')], value=base_model_id)
                    with gr.Column(scale=1):
                        gr.Markdown("### Generated Result")
                        output_image_lora = gr.Image(label="Generated Image", interactive=False, height=512)
                        output_text_lora = gr.Textbox(label="Status", interactive=False, lines=3)
                        generate_btn_lora = gr.Button("Generate Image", variant="primary")
                        stop_btn_lora = gr.Button("Stop Generation")
                gr.Markdown("### Examples for LoRA")
                gr.Examples(examples=examples_lora, inputs=[prompt_lora, height_lora, width_lora, num_inference_steps_lora, guidance_scale_lora, seed_lora, output_image_lora, lora_model_path_lora, base_model_path_lora, lora_rank_lora, lora_scale_lora],
                            outputs=[prompt_lora, height_lora, width_lora, num_inference_steps_lora, guidance_scale_lora, seed_lora, output_image_lora, lora_model_path_lora, base_model_path_lora, lora_rank_lora, lora_scale_lora, output_text_lora],
                            fn=load_example_image_lora, cache_examples=False, examples_per_page=4)

        gr.Markdown(badges_text)

        generate_event_ft = generate_btn_ft.click(fn=generate_image, inputs=[prompt_ft, height_ft, width_ft, num_inference_steps_ft, guidance_scale_ft, seed_ft, random_seed_ft, gr.State(False), finetune_model_path_ft, gr.State(None), gr.State(None), gr.State(None), gr.State(None)],
                                                  outputs=[output_image_ft, output_text_ft])
        generate_event_lora = generate_btn_lora.click(fn=generate_image, inputs=[prompt_lora, height_lora, width_lora, num_inference_steps_lora, guidance_scale_lora, seed_lora, random_seed_lora, gr.State(True), gr.State(None), lora_model_path_lora, base_model_path_lora, lora_rank_lora, lora_scale_lora],
                                                      outputs=[output_image_lora, output_text_lora])

        stop_btn_ft.click(fn=None, inputs=None, outputs=None, cancels=[generate_event_ft])
        stop_btn_lora.click(fn=None, inputs=None, outputs=None, cancels=[generate_event_lora])

        demo.unload(lambda: torch.cuda.empty_cache())

    return demo

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Ghibli-Style Image Generator")
    parser.add_argument("--config_path", type=str, default="configs/model_ckpts.yaml")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--share", action="store_true")
    args = parser.parse_args()
    demo = create_demo(args.config_path, args.device)
    demo.launch(server_port=args.port, share=args.share)
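`load_model_configs()` above treats `configs/model_ckpts.yaml` as a YAML list of model entries keyed by `model_id`, where `type` selects `full_finetuning` or `lora`, `local_dir` is an optional local checkpoint path, and LoRA entries carry a `base_model_id`. A minimal sketch of a compatible file, parsed the same way; the `local_dir` values here are illustrative assumptions, not the paths shipped in this repo:

```python
# Sketch: a model_ckpts.yaml layout that load_model_configs() can parse.
# The local_dir values below are assumptions for illustration only.
import yaml

example_yaml = """
- model_id: danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning
  type: full_finetuning
  local_dir: ckpts/Ghibli-Stable-Diffusion-2.1-Base-finetuning
- model_id: danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA
  type: lora
  base_model_id: stabilityai/stable-diffusion-2-1
  local_dir: ckpts/Ghibli-Stable-Diffusion-2.1-LoRA
"""

# Same dict shape the app builds: {model_id: config}.
model_configs = {cfg["model_id"]: cfg for cfg in yaml.safe_load(example_yaml)}
lora_cfg = model_configs["danhtran2mind/Ghibli-Stable-Diffusion-2.1-LoRA"]
print(lora_cfg.get("base_model_id", "stabilityai/stable-diffusion-2-1"))
```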
assets/.gitkeep
ADDED
@@ -0,0 +1 @@
assets/demo_image.png
ADDED
Git LFS Details
assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/1/config.json
ADDED
@@ -0,0 +1,11 @@
{
    "prompt": "a serene landscape in Ghibli style",
    "height": 256,
    "width": 512,
    "num_inference_steps": 50,
    "guidance_scale": 3.5,
    "seed": 42,
    "image": "result.png",
    "use_lora": false,
    "finetune_model_id": "danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning"
}
assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/1/result.png
ADDED
Git LFS Details
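Each bundled example directory pairs a `config.json` like the one above with a `result.png`; `get_examples()` only accepts a directory when both files exist, the required keys for the chosen mode are present, and `image` is exactly `"result.png"`. A minimal sketch of validating and loading one such directory in the same way (the path is simply the first bundled example):

```python
# Sketch: validate and load one bundled example the way get_examples() does.
import json
import os
from PIL import Image

example_dir = "assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/1"
config_path = os.path.join(example_dir, "config.json")
image_path = os.path.join(example_dir, "result.png")

with open(config_path) as f:
    config = json.load(f)

# Non-LoRA examples require finetune_model_id; LoRA examples require
# lora_model_id, base_model_id, lora_rank and lora_scale instead.
required = ["prompt", "height", "width", "num_inference_steps",
            "guidance_scale", "seed", "image", "finetune_model_id"]
missing = set(required) - set(config)
assert not missing and config["image"] == "result.png", f"invalid example: {missing}"

image = Image.open(image_path)
print(config["prompt"], image.size)
```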
assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/2/config.json
ADDED
@@ -0,0 +1,11 @@
{
    "prompt": "Donald Trump",
    "height": 512,
    "width": 512,
    "num_inference_steps": 100,
    "guidance_scale": 9,
    "seed": 200,
    "image": "result.png",
    "use_lora": false,
    "finetune_model_id": "danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning"
}
assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/2/result.png
ADDED
Git LFS Details
assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/3/config.json
ADDED
@@ -0,0 +1,11 @@
{
    "prompt": "a dancer in Ghibli style",
    "height": 384,
    "width": 192,
    "num_inference_steps": 50,
    "guidance_scale": 15.5,
    "seed": 4223,
    "image": "result.png",
    "use_lora": false,
    "finetune_model_id": "danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning"
}
assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/3/result.png
ADDED
Git LFS Details
assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/4/config.json
ADDED
@@ -0,0 +1,11 @@
{
    "prompt": "Ghibli style, the peace beach",
    "height": 1024,
    "width": 2048,
    "num_inference_steps": 100,
    "guidance_scale": 7.5,
    "seed": 5678,
    "image": "result.png",
    "use_lora": false,
    "finetune_model_id": "danhtran2mind/Ghibli-Stable-Diffusion-2.1-Base-finetuning"
}
assets/examples/Ghibli-Stable-Diffusion-2.1-Base-finetuning/4/result.png
ADDED
Git LFS Details
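One note on the example dimensions: `generate_image()` rejects heights and widths not divisible by 8 because the UNet denoises latents of shape `(1, in_channels, height // 8, width // 8)`, and every bundled config above satisfies that constraint. A quick, purely illustrative check (4 latent channels is the usual Stable Diffusion 2.1 value, assumed here):

```python
# Sketch: latent spatial sizes implied by the bundled example configs.
configs = [(256, 512), (512, 512), (384, 192), (1024, 2048)]  # (height, width)
for height, width in configs:
    assert height % 8 == 0 and width % 8 == 0  # same rule generate_image() enforces
    print(f"{height}x{width} -> latents (1, 4, {height // 8}, {width // 8})")
```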