| import os |
| import spaces |
| from typing import Optional |
|
|
| import gradio as gr |
| import numpy as np |
| import pandas as pd |
| import torch |
| from PIL import Image |
|
|
| from pipeline_interpolated_sd import InterpolationStableDiffusionPipeline |
| from pipeline_interpolated_sdxl import InterpolationStableDiffusionXLPipeline |
| from prior import BetaPriorPipeline |
|
|
|
|
# Set TOKENIZERS_PARALLELISM to "false" for this process.
# NOTE(review): presumably silences the tokenizers fork/parallelism warning — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


# Page heading; rendered through gr.Markdown at the top of the demo.
title = r"""
<h1 align="center">PAID: (Prompt-guided) Attention Interpolation of Text-to-Image Diffusion</h1>
"""


# Usage instructions; rendered through gr.Markdown right below the heading.
description = r"""
<b>Official 🤗 Gradio demo</b> for <a href='https://github.com/QY-H00/attention-interpolation-diffusion/tree/public' target='_blank'><b>PAID: (Prompt-guided) Attention Interpolation of Text-to-Image Diffusion</b></a>.<br>
We strongly recommend to deploy this demo locally with GPU resources for better experience, especially for XL-series since the GPU resources is limited in this space. <br>
How to use:<br>
1. Input prompt 1, prompt 2 and negative prompt.
2. For <b> Compositional Generation </b> Input the guidance prompt and choose the one you are satisfied!
3. For <b> Image morphing </b> Input the image prompt 1 and image prompt 2, and choose IP-Adapter.
4. For <b> Scale Control </b> Input the same text for prompt 1 and prompt 2, leave image prompt 1 blank and upload image prompt 2. Then choose IP-Adapter or IP-Composition-Adapter.
5. <b> Note that the time required for the SD-series with an exploration size of 10 is around 120 seconds. XL-series with an exploration size 5 is around 5 minutes 30 seconds. </b>
6. Click the <b>Generate</b> button to begin generating images.
7. Enjoy! 😊"""


# Citation and contact footer; rendered through gr.Markdown at the bottom of the demo.
article = r"""
---
✒️ **Citation**
<br>
If you found this demo/our paper useful, please consider citing:
```bibtex
@article{he2024aid,
title={AID: Attention Interpolation of Text-to-Image Diffusion},
author={He, Qiyuan and Wang, Jinghao and Liu, Ziwei and Yao, Angela},
journal={arXiv preprint arXiv:2403.17924},
year={2024}
}
```
📧 **Contact**
<br>
If you have any questions, please feel free to open an issue in our <a href='https://github.com/QY-H00/attention-interpolation-diffusion/tree/public' target='_blank'><b>Github Repo</b></a> or directly reach us out at <b>qhe@u.nus.edu.sg</b>.
"""
|
|
# Upper bound of the seed slider: the largest value of a signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
# Forwarded to gr.Examples(cache_examples=...).
CACHE_EXAMPLES = False
# NOTE(review): the three flags below are not read anywhere in the visible source.
USE_TORCH_COMPILE = False
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
PREVIEW_IMAGES = False


# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Default pipeline, loaded in float16 when the module is imported. It is the
# module-level object that change_model_fn replaces and generate() reads.
pipeline = InterpolationStableDiffusionPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V4.0_noVAE",
    torch_dtype=torch.float16
)
pipeline.to(device, dtype=torch.float16)
|
|
|
|
def change_model_fn(model_name: str) -> None:
    """Replace the module-level ``pipeline`` with the model picked in the UI.

    Names containing "XL" are loaded with
    ``InterpolationStableDiffusionXLPipeline``; every other name is loaded
    with ``InterpolationStableDiffusionPipeline``. Weights are always loaded
    in float16 and moved to the module-level ``device``.

    Args:
        model_name: Display name coming from the model dropdown.

    Raises:
        KeyError: If ``model_name`` has no entry in the name mapping.
    """
    global pipeline
    name_mapping = {
        "AOM3": "hogiahien/aom3",
        # The model dropdown offers "SD1.4-512"; without this entry selecting
        # it failed with a KeyError.
        "SD1.4-512": "CompVis/stable-diffusion-v1-4",
        "SD1.5-512": "stable-diffusion-v1-5/stable-diffusion-v1-5",
        "SD2.1-768": "stabilityai/stable-diffusion-2-1",
        "RealVis-v4.0": "SG161222/Realistic_Vision_V4.0_noVAE",
        "SDXL-1024": "stabilityai/stable-diffusion-xl-base-1.0",
        "Playground-XL-v2": "playgroundai/playground-v2.5-1024px-aesthetic",
        "Juggernaut-XL-v9": "RunDiffusion/Juggernaut-XL-v9",
    }
    repo_id = name_mapping[model_name]
    # float16 is used on every device, so no per-device dtype switch is needed.
    dtype = torch.float16
    if "XL" not in model_name:
        pipeline = InterpolationStableDiffusionPipeline.from_pretrained(
            repo_id, torch_dtype=dtype
        )
        pipeline.to(device, dtype=torch.float16)
    else:
        pipeline = InterpolationStableDiffusionXLPipeline.from_pretrained(
            repo_id, torch_dtype=dtype
        )
        pipeline.to(device)
|
|
|
|
def change_adapter_fn(adapter_name: str) -> None:
    """Load the attention processors matching the adapter chosen in the UI.

    "IP-Adapter" and "IP-Composition-Adapter" load the corresponding AID
    IP-Adapter weights (the ``sd15`` files for an
    ``InterpolationStableDiffusionPipeline``, the ``sdxl`` files otherwise).
    Any other name falls back to plain ``pipeline.load_aid()``.

    Args:
        adapter_name: Value of the adapter dropdown.
    """
    uses_sd_pipeline = isinstance(pipeline, InterpolationStableDiffusionPipeline)

    if adapter_name == "IP-Adapter":
        if uses_sd_pipeline:
            pipeline.load_aid_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
        else:
            pipeline.load_aid_ip_adapter("ozzygt/sdxl-ip-adapter", "", weight_name="ip-adapter-plus_sdxl_vit-h.safetensors")
        return

    if adapter_name == "IP-Composition-Adapter":
        if uses_sd_pipeline:
            pipeline.load_aid_ip_adapter("ostris/ip-composition-adapter", subfolder="", weight_name="ip_plus_composition_sd15.safetensors")
        else:
            pipeline.load_aid_ip_adapter("ozzygt/sdxl-ip-adapter", subfolder="", weight_name="ip_plus_composition_sdxl.safetensors")
        return

    # No adapter selected: use the text-only AID attention processors.
    pipeline.load_aid()
|
|
|
|
def save_image(img, index):
    """Save an image array as ``<index>.png`` and return that file name.

    Args:
        img: Array accepted by ``PIL.Image.fromarray``.
        index: Value used as the file stem.

    Returns:
        The relative file name the image was written to.
    """
    file_name = f"{index}.png"
    Image.fromarray(img).save(file_name)
    return file_name
|
|
|
|
def get_example() -> list[list[str | float | int]]:
    """Build the rows for the demo's examples table.

    Each row lists, in order: prompt 1, prompt 2, negative prompt, guidance
    prompt, image prompt 1, image prompt 2, inference steps, exploration
    size, interpolation size, guidance scale, warmup ratio, model name,
    adapter name, seed and the same-latent flag. Image prompts are either
    ``None`` or images opened from the ``asset/`` directory on every call.
    """
    # Fragments shared by several rows.
    negative = "nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]"
    quality_tags = "masterpiece, best quality, very aesthetic, absurdres, "
    dog_prompt = quality_tags + "A dog"
    car_prompt = quality_tags + "A car"

    # Text-only interpolation.
    text_only = [
        "A statue", "A dragon", negative, "",
        None, None,
        50, 10, 5, 5.0, 0.5,
        "RealVis-v4.0", "None", 0, True,
    ]
    # Image morphing: both image prompts supplied.
    morphing = [
        "A photo of a statue", "Het meisje met de parel, by Vermeer", negative, "",
        Image.open("asset/statue.jpg"), Image.open("asset/vermeer.jpg"),
        50, 10, 5, 5.0, 0.5,
        "RealVis-v4.0", "IP-Adapter", 0, True,
    ]
    # Scale control: identical text prompts, only the second image supplied.
    scale_control = [
        "A boy is smiling", "A boy is smiling", negative, "",
        None, Image.open("asset/vermeer.jpg"),
        50, 10, 5, 5.0, 0.5,
        "RealVis-v4.0", "IP-Composition-Adapter", 0, True,
    ]
    # Compositional generation driven by a guidance prompt.
    guided_toy = [
        dog_prompt, car_prompt, negative,
        quality_tags + "the toy, named 'Dog-Car', is designed as a dog figure with car wheels instead of feet",
        None, None,
        50, 5, 5, 5.0, 0.5,
        "RealVis-v4.0", "None", 1002, True,
    ]
    # Same idea on an XL-series model with fewer inference steps.
    guided_driving = [
        dog_prompt, car_prompt, negative,
        quality_tags + "a dog is driving a car",
        None, None,
        28, 5, 5, 5.0, 0.5,
        "Playground-XL-v2", "None", 1002, True,
    ]
    return [text_only, morphing, scale_control, guided_toy, guided_driving]
|
|
|
|
def change_generate_button_fn(enable: int) -> gr.Button:
    """Return the Generate button in its enabled or disabled state.

    Args:
        enable: ``0`` yields a non-interactive button labelled
            "Switching Model..."; any other value yields the interactive
            "Generate" button.
    """
    if enable != 0:
        return gr.Button(interactive=True, value="Generate")
    return gr.Button(interactive=False, value="Switching Model...")
|
|
|
|
def dynamic_gallery_fn(interpolation_size: int):
    """Return a one-row result gallery with one column per interpolated image.

    Args:
        interpolation_size: Number of gallery columns.
    """
    gallery_options = {
        "label": "Result",
        "show_label": False,
        "rows": 1,
        "columns": interpolation_size,
    }
    return gr.Gallery(**gallery_options)
|
|
def _load_selected_ip_adapter(adapter_name, **extra_kwargs):
    """Load the AID IP-Adapter weights named by ``adapter_name`` into ``pipeline``.

    An ``InterpolationStableDiffusionPipeline`` receives the ``sd15`` weight
    files, any other pipeline the ``sdxl`` ones. Names other than
    "IP-Adapter" and "IP-Composition-Adapter" load nothing. ``extra_kwargs``
    are forwarded to ``load_aid_ip_adapter`` unchanged.
    """
    uses_sd_pipeline = isinstance(pipeline, InterpolationStableDiffusionPipeline)
    if adapter_name == "IP-Adapter":
        if uses_sd_pipeline:
            pipeline.load_aid_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin", **extra_kwargs)
        else:
            pipeline.load_aid_ip_adapter("ozzygt/sdxl-ip-adapter", "", weight_name="ip-adapter-plus_sdxl_vit-h.safetensors", **extra_kwargs)
    elif adapter_name == "IP-Composition-Adapter":
        if uses_sd_pipeline:
            pipeline.load_aid_ip_adapter("ostris/ip-composition-adapter", subfolder="", weight_name="ip_plus_composition_sd15.safetensors", **extra_kwargs)
        else:
            pipeline.load_aid_ip_adapter("ozzygt/sdxl-ip-adapter", subfolder="", weight_name="ip_plus_composition_sdxl.safetensors", **extra_kwargs)


@spaces.GPU(duration=110)
@torch.no_grad()
def generate(
    prompt1,
    prompt2,
    negative_prompt,
    guide_prompt=None,
    image_prompt1=None,
    image_prompt2=None,
    num_inference_steps=28,
    exploration_size=16,
    interpolation_size=7,
    guidance_scale=5.0,
    warmup_ratio=0.5,
    seed=0,
    same_latent=True,
) -> np.ndarray:
    """Generate an interpolation sequence between two prompts.

    Seeds the random generator, draws the start/end latents, loads the
    attention processors that match the supplied image prompts and delegates
    to ``BetaPriorPipeline.generate_interpolation``.

    Image prompt handling:
      * neither image: text-only AID (``pipeline.load_aid()``);
      * both images: image morphing with the selected IP adapter;
      * only ``image_prompt2``: scale control, i.e. the selected IP adapter
        loaded with ``early="scale_control"``;
      * only ``image_prompt1``: rejected.

    Args:
        prompt1: Text prompt for the first endpoint.
        prompt2: Text prompt for the second endpoint.
        negative_prompt: Negative prompt forwarded to the pipeline.
        guide_prompt: Optional guidance prompt; an empty string is treated
            as ``None``.
        image_prompt1: Optional image for the first endpoint.
        image_prompt2: Optional image for the second endpoint.
        num_inference_steps: Forwarded to ``generate_interpolation``.
        exploration_size: Forwarded to ``generate_interpolation``.
        interpolation_size: Forwarded to ``generate_interpolation``.
        guidance_scale: Forwarded to ``generate_interpolation``.
        warmup_ratio: Forwarded to ``generate_interpolation``.
        seed: Seed for the latent noise.
        same_latent: When true, both endpoints start from the same latent.

    Returns:
        Whatever ``generate_interpolation`` returns for ``output_type="np"``.

    Raises:
        ValueError: If only ``image_prompt1`` is provided.
    """
    beta_pipe = BetaPriorPipeline(pipeline)
    if guide_prompt == "":
        guide_prompt = None

    cuda_available = torch.cuda.is_available()
    # torch.cuda.manual_seed seeds the default CUDA generator and, per the
    # PyTorch docs, returns None, so on GPU the latents are drawn from that
    # freshly seeded default generator; torch.manual_seed returns the CPU one.
    generator = (
        torch.cuda.manual_seed(seed)
        if cuda_available
        else torch.manual_seed(seed)
    )
    # Sample on the same device the generator belongs to. A hard-coded "cuda"
    # here fails on CPU-only hosts, where the generator is a CPU generator.
    latent_device = "cuda" if cuda_available else "cpu"
    size = pipeline.unet.config.sample_size
    latent_shape = (1, 4, size, size)
    latent1 = torch.randn(latent_shape, device=latent_device, dtype=pipeline.unet.dtype, generator=generator)
    if same_latent:
        latent2 = latent1.clone()
    else:
        latent2 = torch.randn(latent_shape, device=latent_device, dtype=pipeline.unet.dtype, generator=generator)

    has_image1 = image_prompt1 is not None
    has_image2 = image_prompt2 is not None
    # NOTE(review): `adapter_choice.value` looks like the dropdown's configured
    # value rather than the user's live selection — confirm; passing the
    # dropdown as an event input would be the reliable source.
    if not has_image1 and not has_image2:
        pipeline.load_aid()
    elif has_image1 and has_image2:
        # Image morphing. This test must differ from the scale-control test
        # below; with identical tests, morphing requests hit the ValueError
        # and the scale-control branch can never run.
        _load_selected_ip_adapter(adapter_choice.value)
    elif has_image2:
        # Scale control: only the right-hand image is given.
        _load_selected_ip_adapter(adapter_choice.value, early="scale_control")
    else:
        raise ValueError("To use scale control, please provide only the right image; To use image morphing, please provide images from both side.")

    images = beta_pipe.generate_interpolation(
        gr.Progress(),
        prompt1,
        prompt2,
        negative_prompt,
        latent1,
        latent2,
        num_inference_steps,
        image_start=image_prompt1,
        image_end=image_prompt2,
        exploration_size=exploration_size,
        interpolation_size=interpolation_size,
        output_type="np",
        guide_prompt=guide_prompt,
        guidance_scale=guidance_scale,
        warmup_ratio=warmup_ratio
    )
    return images
|
|
|
|
# Gradio UI definition and launch.
# NOTE(review): the source's indentation was lost, so the nesting of the layout
# containers below is reconstructed — confirm against the intended layout.

# Placeholder; rebound to the "Interpolation Size" slider inside the Blocks context.
interpolation_size = None


with gr.Blocks(css="style.css") as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    # --- Text prompts and model selection ---
    with gr.Group():
        with gr.Column():
            with gr.Row():
                prompt1 = gr.Text(
                    label="Prompt 1",
                    max_lines=3,
                    placeholder="Enter the First Prompt",
                    interactive=True,
                    value="A statue",
                )
                prompt2 = gr.Text(
                    label="Prompt 2",
                    max_lines=3,
                    placeholder="Enter the Second Prompt",
                    interactive=True,
                    value="A dragon",
                )
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=3,
                placeholder="Enter a Negative Prompt",
                interactive=True,
                value="nsfw, lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract]",
            )
            guidance_prompt = gr.Text(
                label="Guidance prompt (Optional)",
                max_lines=3,
                placeholder="Enter a Guidance Prompt",
                interactive=True,
                value="",
            )
            model_choice = gr.Dropdown(
                ["RealVis-v4.0", "SD1.4-512", "SD1.5-512", "SD2.1-768", "AOM3", "SDXL-1024", "Playground-XL-v2", "Juggernaut-XL-v9"],
                label="Model",
                value="RealVis-v4.0",
                interactive=True,
                info="All series are running on float16; SD2.1 does not support IP-Adapter; XL-Series takes longer time",
            )


    # --- Optional image prompts and adapter selection ---
    with gr.Group():
        with gr.Row():
            image_prompt1 = gr.Image(label="Image Prompt 1 (Optional)", interactive=True)
            image_prompt2 = gr.Image(label="Image Prompt 2 (Optional)", interactive=True)
        adapter_choice = gr.Dropdown(
            ["None", "IP-Adapter", "IP-Composition-Adapter"],
            label="IP-Adapter",
            value="None",
            interactive=True,
            info="Only set to IP-Adapter or IP-Composition-Adapter when using image prompt",
        )


    # --- Output gallery and trigger button ---
    with gr.Group():
        result = gr.Gallery(label="Result", show_label=False, rows=1, columns=3)
        generate_button = gr.Button(value="Generate", variant="primary")


    # --- Sampling parameters ---
    with gr.Accordion("Advanced options", open=True):
        with gr.Group():
            with gr.Row():
                with gr.Column():
                    interpolation_size = gr.Slider(
                        label="Interpolation Size",
                        minimum=3,
                        maximum=5,
                        step=1,
                        value=3,
                        info="Interpolation size includes the start and end images",
                    )
                    exploration_size = gr.Slider(
                        label="Exploration Size",
                        minimum=5,
                        maximum=10,
                        step=1,
                        value=5,
                        info="Exploration size has to be larger than interpolation size",
                    )
            with gr.Row():
                with gr.Column():
                    warmup_ratio = gr.Slider(
                        label="Warmup Ratio",
                        minimum=0.02,
                        maximum=1,
                        step=0.01,
                        value=0.5,
                        interactive=True,
                    )
                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=0,
                        maximum=20,
                        step=0.1,
                        value=5.0,
                        interactive=True,
                    )
                    num_inference_steps = gr.Slider(
                        label="Inference Steps",
                        minimum=25,
                        maximum=50,
                        step=1,
                        value=50,
                        interactive=True,
                    )
                with gr.Column():
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
                    same_latent = gr.Checkbox(
                        label="Same latent",
                        value=False,
                        info="Use the same latent for start and end images",
                        show_label=True,
                    )


    # Clickable examples; the input order must match the column order of the
    # rows returned by get_example().
    gr.Examples(
        examples=get_example(),
        inputs=[
            prompt1,
            prompt2,
            negative_prompt,
            guidance_prompt,
            image_prompt1,
            image_prompt2,
            num_inference_steps,
            exploration_size,
            interpolation_size,
            guidance_scale,
            warmup_ratio,
            model_choice,
            adapter_choice,
            seed,
            same_latent,
        ],
        cache_examples=CACHE_EXAMPLES,
    )


    # Model change: disable the Generate button (hidden Number 0), swap the
    # pipeline, then re-enable the button (hidden Number 1).
    model_choice.change(
        fn=change_generate_button_fn,
        inputs=gr.Number(0, visible=False),
        outputs=generate_button,
    ).then(fn=change_model_fn, inputs=model_choice).then(
        fn=change_generate_button_fn,
        inputs=gr.Number(1, visible=False),
        outputs=generate_button,
    )


    # Adapter change: same disable / load / re-enable sequence as above.
    adapter_choice.change(
        fn=change_generate_button_fn,
        inputs=gr.Number(0, visible=False),
        outputs=generate_button,
    ).then(fn=change_adapter_fn, inputs=[adapter_choice]).then(
        fn=change_generate_button_fn,
        inputs=gr.Number(1, visible=False),
        outputs=generate_button,
    )


    # Positional inputs for generate(); the order must match its signature.
    inputs = [
        prompt1,
        prompt2,
        negative_prompt,
        guidance_prompt,
        image_prompt1,
        image_prompt2,
        num_inference_steps,
        exploration_size,
        interpolation_size,
        guidance_scale,
        warmup_ratio,
        seed,
        same_latent,
    ]
    # Generate click: first resize the gallery to one column per interpolated
    # image, then run the generation and show its images.
    generate_button.click(
        fn=dynamic_gallery_fn,
        inputs=interpolation_size,
        outputs=result,
    ).then(
        fn=generate,
        inputs=inputs,
        outputs=result,
    )
    gr.Markdown(article)


demo.launch()
|
|