Gemini899 committed
Commit 214415a · verified · 1 Parent(s): eff2e3e

Update app.py

Files changed (1):
  1. app.py +80 -167

app.py CHANGED
@@ -1,64 +1,29 @@
-import os
-import time
-import numpy as np
-from PIL import Image
-
-# --- Hugging Face and Diffusers Imports ---
 import huggingface_hub
 # Monkey-patch: if cached_download is missing, alias it to hf_hub_download.
 if not hasattr(huggingface_hub, "cached_download"):
     huggingface_hub.cached_download = huggingface_hub.hf_hub_download
+
 print("huggingface_hub version:", huggingface_hub.__version__)
 
 import diffusers
 print("diffusers version:", diffusers.__version__)
 import numpy
 print("numpy version:", numpy.__version__)
+import spaces
 import gradio as gr
 import torch
 from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL
+from PIL import Image
+import os
+import time
+
+from utils.dl_utils import dl_cn_model, dl_cn_config, dl_tagger_model, dl_lora_model
+from utils.image_utils import resize_image_aspect_ratio, base_generation
 
-# --- Additional Utilities and Download Functions ---
-# (Assuming these are defined elsewhere in your project. For integration, dummy definitions are provided.)
-def dl_cn_model(cn_dir):
-    # Dummy: in your project, this downloads the ControlNet model
-    print("Downloading ControlNet model to", cn_dir)
-def dl_cn_config(cn_dir):
-    print("Downloading ControlNet config to", cn_dir)
-def dl_tagger_model(tagger_dir):
-    print("Downloading tagger model to", tagger_dir)
-def dl_lora_model(lora_dir):
-    print("Downloading LoRA model to", lora_dir)
-
-# Dummy image utility functions
-def resize_image_aspect_ratio(image):
-    # For demonstration, return the image as is.
-    return image
-
-def base_generation(size, color):
-    # Create a blank image with the given color and size.
-    return Image.new("RGBA", size, color)
-
-# Dummy prompt utilities
-def execute_prompt(tags, prompt):
-    # In your project, this may combine tags and prompt
-    return prompt + " " + ", ".join(tags)
-def remove_color(prompt):
-    # Dummy function: simply return prompt unchanged.
-    return prompt
-def remove_duplicates(prompt):
-    # Dummy function: simply return prompt unchanged.
-    return prompt
-
-# Dummy tagger function (if needed)
-def modelLoad(tagger_dir):
-    # Return a dummy model
-    return None
-def analysis(image_path, tagger_dir, tagger_model):
-    # Return dummy tags
-    return ["lineart", "sketch"]
-
-# --- Set Up Directories ---
+from utils.prompt_utils import execute_prompt, remove_color, remove_duplicates
+from utils.tagger import modelLoad, analysis
+
+# Set up directories
 path = os.getcwd()
 cn_dir = os.path.join(path, "controlnet")
 tagger_dir = os.path.join(path, "tagger")
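Note: this commit swaps the inline dummy helpers for imports from a `utils` package that is assumed to live alongside app.py in the Space repo; the package itself is not part of this diff. (The `cached_download` monkey-patch is kept: older diffusers releases still reference `huggingface_hub.cached_download`, which newer huggingface_hub versions removed, so the alias guards against that combination.) To run the new file outside the Space without the real package, a minimal stub in the spirit of the dummies removed above could look like this; the file path and signatures are inferred from the import lines, not confirmed by this commit:

# utils/dl_utils.py, hypothetical local stub mirroring the dummies removed above;
# the real module in the Space presumably downloads actual model weights.
def dl_cn_model(cn_dir):
    print("Downloading ControlNet model to", cn_dir)

def dl_cn_config(cn_dir):
    print("Downloading ControlNet config to", cn_dir)

def dl_tagger_model(tagger_dir):
    print("Downloading tagger model to", tagger_dir)

def dl_lora_model(lora_dir):
    print("Downloading LoRA model to", lora_dir)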
@@ -67,43 +32,46 @@ os.makedirs(cn_dir, exist_ok=True)
 os.makedirs(tagger_dir, exist_ok=True)
 os.makedirs(lora_dir, exist_ok=True)
 
-# --- Download Required Models/Configs ---
+# Download required models and configs
 dl_cn_model(cn_dir)
 dl_cn_config(cn_dir)
 dl_tagger_model(tagger_dir)
 dl_lora_model(lora_dir)
 
-# --- Diffusers-Based Model Loading and Predict Function ---
 def load_model(lora_dir, cn_dir):
     dtype = torch.float16
-    vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype)
+    vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
     controlnet = ControlNetModel.from_pretrained(cn_dir, torch_dtype=dtype, use_safetensors=True)
+
     pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
-        "cagliostrolab/animagine-xl-3.1", controlnet=controlnet, vae=vae, torch_dtype=dtype
+        "cagliostrolab/animagine-xl-3.1", controlnet=controlnet, vae=vae, torch_dtype=torch.float16
     )
     pipe.enable_model_cpu_offload()
-    # Load your LoRA weights (assumes the file "lineart.safetensors" is in lora_dir)
     pipe.load_lora_weights(lora_dir, weight_name="lineart.safetensors")
     return pipe
 
+@spaces.GPU(duration=120)
 def predict(input_image_path, prompt, negative_prompt, controlnet_scale):
-    pipe = load_model(lora_dir, cn_dir)
-    input_image = Image.open(input_image_path).convert("RGB")
+    pipe = load_model(lora_dir, cn_dir)
+    input_image = Image.open(input_image_path)
     base_image = base_generation(input_image.size, (255, 255, 255, 255)).convert("RGB")
-    resize_img = resize_image_aspect_ratio(input_image)
-    resize_base = resize_image_aspect_ratio(base_image)
+    resize_image = resize_image_aspect_ratio(input_image)
+    resize_base_image = resize_image_aspect_ratio(base_image)
     generator = torch.manual_seed(0)
     last_time = time.time()
-
-    # Here we assume the prompt is already generated (or modified) by the new Janus logic.
-    final_prompt = prompt  # Optionally, prepend additional base phrases if needed.
-    print("Final prompt:", final_prompt)
+    # Prepend a base prompt to get best results
+    prompt = "masterpiece, best quality, monochrome, sharp uniform black lines, vector style, very thick lineart, clean lineart, no shading, solid very thick black lines, no gradients, white background, " + prompt
+    execute_tags = ["sketch", "transparent background"]
+    prompt = execute_prompt(execute_tags, prompt)
+    prompt = remove_duplicates(prompt)
+    prompt = remove_color(prompt)
+    print(prompt)
 
     output_image = pipe(
-        image=resize_base,
-        control_image=resize_img,
+        image=resize_base_image,
+        control_image=resize_image,
         strength=1.0,
-        prompt=final_prompt,
+        prompt=prompt,
         negative_prompt=negative_prompt,
         controlnet_conditioning_scale=float(controlnet_scale),
         generator=generator,
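Note: `predict` now carries `@spaces.GPU(duration=120)` (the ZeroGPU decorator from the Spaces `spaces` package, which allocates a GPU for up to 120 seconds per call) and rebuilds the full SDXL pipeline on every call via `load_model`, so each generation pays the model-loading cost. If that becomes a bottleneck, one option is to memoize the pipeline. A minimal sketch, assuming `load_model` stays as defined above; whether a cached pipeline survives across ZeroGPU allocations depends on how the Space manages worker processes, so treat this as something to test, not a given:

from functools import lru_cache

@lru_cache(maxsize=1)
def get_pipe(lora_dir, cn_dir):
    # Build the heavy SDXL + ControlNet pipeline once and reuse it
    # on subsequent calls with the same directory arguments.
    return load_model(lora_dir, cn_dir)

Inside `predict`, `pipe = load_model(lora_dir, cn_dir)` would then become `pipe = get_pipe(lora_dir, cn_dir)`.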
@@ -114,91 +82,35 @@ def predict(input_image_path, prompt, negative_prompt, controlnet_scale):
     output_image = output_image.resize(input_image.size, Image.LANCZOS)
     return output_image
 
-# --- Janus-Based Prompt Generation Function ---
-# This code is taken from your second Hugging Face app.
-from transformers import AutoConfig, AutoModelForCausalLM
-from janus.models import VLChatProcessor
-
-# Set up Janus model and processor
-janus_model_path = "deepseek-ai/Janus-Pro-1B"
-janus_config = AutoConfig.from_pretrained(janus_model_path)
-language_config = janus_config.language_config
-language_config._attn_implementation = 'eager'
-vl_gpt = AutoModelForCausalLM.from_pretrained(
-    janus_model_path,
-    language_config=language_config,
-    trust_remote_code=True
-)
-if torch.cuda.is_available():
-    vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
-else:
-    vl_gpt = vl_gpt.to(torch.float16)
-vl_chat_processor = VLChatProcessor.from_pretrained(janus_model_path)
-tokenizer = vl_chat_processor.tokenizer
-cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
-@torch.inference_mode()
-@gr.analytics.track()  # Using gradio decorator for potential GPU allocation if using spaces; you can remove if not needed.
-def generate_prompt_from_image(input_image_path, seed: int = 42, top_p: float = 0.95, temperature: float = 0.1):
-    # Open image
-    image = Image.open(input_image_path).convert("RGB")
-    image_np = np.array(image)
-
-    # Set seed for reproducibility
-    torch.cuda.empty_cache()
-    torch.manual_seed(seed)
-    np.random.seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
-
-    # Construct conversation for the Janus model
-    conversation = [
-        {
-            "role": "<|User|>",
-            "content": (
-                "<image_placeholder>\nGenerate a detailed artistic prompt for extracting crisp, high-quality lineart. "
-                "The prompt should include phrases like 'masterpiece, best quality, monochrome, sharp uniform black lines, "
-                "vector style, very thick lineart, clean lineart, no shading, solid very thick black lines, no gradients, white background'."
-            ),
-            "images": [image_np],
-        },
-        {"role": "<|Assistant|>", "content": ""},
-    ]
-    pil_images = [image]
-    prepare_inputs = vl_chat_processor(
-        conversations=conversation, images=pil_images, force_batchify=True
-    ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
-    inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
-
-    outputs = vl_gpt.language_model.generate(
-        inputs_embeds=inputs_embeds,
-        attention_mask=prepare_inputs.attention_mask,
-        pad_token_id=tokenizer.eos_token_id,
-        bos_token_id=tokenizer.bos_token_id,
-        eos_token_id=tokenizer.eos_token_id,
-        max_new_tokens=64,
-        do_sample=True if temperature != 0 else False,
-        use_cache=True,
-        temperature=temperature,
-        top_p=top_p,
-    )
-
-    answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
-    return answer
+@spaces.GPU(duration=120)
+def prompt_analysis(input_image_path):
+    """
+    Run prompt analysis on the given image.
+    Loads the tagger model, runs analysis, cleans the tags, and returns a string.
+    """
+    # Load the tagger model using the tagger_dir (set earlier in the file)
+    tagger_model = modelLoad(tagger_dir)
+    tags = analysis(input_image_path, tagger_dir, tagger_model)
+    tags_clean = remove_color(tags)
+    if isinstance(tags_clean, (list, tuple)):
+        return ", ".join(tags_clean)
+    return tags_clean
+
 
-# --- Gradio Interface ---
 class Img2Img:
     def __init__(self):
         self.demo = self.layout()
-
-    def generate_prompt_callback(self, input_image_path):
-        # Generate prompt using Janus model
-        return generate_prompt_from_image(input_image_path)
-
-    def generate_image_callback(self, input_image_path, prompt, negative_prompt, controlnet_scale):
-        # Use the provided prompt (which may have been auto-generated) to generate the image
-        return predict(input_image_path, prompt, negative_prompt, controlnet_scale)
-
+        self.tagger_model = None
+        self.input_image_path = None
+        self.canny_image = None
+
+    def process_prompt_analysis(self, input_image_path):
+        if self.tagger_model is None:
+            self.tagger_model = modelLoad(tagger_dir)
+        tags = analysis(input_image_path, tagger_dir, self.tagger_model)
+        tags_list = remove_color(tags)
+        return tags_list
+
     def layout(self):
         css = """
         #intro{
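Note: the module-level `prompt_analysis` and the `Img2Img.process_prompt_analysis` method added in the hunk above do nearly the same work, but only the function normalizes a tag list into a comma-separated string; the method returns whatever `remove_color` yields, and only the method is wired into the UI below. If both stay, a shared helper would keep them from drifting. A sketch, assuming `analysis` may return either a string or a list of tags (the diff leaves the return type unspecified):

def tags_to_prompt(input_image_path, tagger_model):
    # Run the tagger and always hand back a comma-separated string
    # suitable for the Prompt textbox.
    tags = remove_color(analysis(input_image_path, tagger_dir, tagger_model))
    if isinstance(tags, (list, tuple)):
        return ", ".join(tags)
    return tags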
@@ -210,34 +122,35 @@ class Img2Img:
         with gr.Blocks(css=css) as demo:
             with gr.Row():
                 with gr.Column():
-                    input_image = gr.Image(label="Input Image", type="filepath")
-                    prompt = gr.Textbox(label="Prompt", lines=3)
-                    negative_prompt = gr.Textbox(
-                        label="Negative Prompt",
-                        lines=3,
-                        value="sketch, lowres, error, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, blurry"
-                    )
-                    controlnet_scale = gr.Slider(
-                        minimum=0.5, maximum=1.25, value=1.0, step=0.01, label="Lineart Fidelity"
+                    self.input_image_path = gr.Image(label="Input image", type='filepath')
+                    self.prompt = gr.Textbox(label="Prompt", lines=3)
+                    self.negative_prompt = gr.Textbox(
+                        label="Negative prompt",
+                        lines=3,
+                        value="sketch, lowres, error, extra digit, fewer digits, cropped, worst quality,low quality, normal quality, jpeg artifacts, blurry"
                     )
-                    generate_prompt_btn = gr.Button("Generate Prompt")
-                    generate_image_btn = gr.Button("Generate Image", variant="primary")
+                    # Button to run prompt analysis locally (UI callback)
+                    prompt_analysis_button = gr.Button("Prompt analysis")
+                    self.controlnet_scale = gr.Slider(
+                        minimum=0.5, maximum=1.25, value=1.0, step=0.01, label="Lineart fidelity"
+                    )
+                    generate_button = gr.Button(value="Generate", variant="primary")
                 with gr.Column():
-                    output_image = gr.Image(type="pil", label="Output Image")
-
-                # Button callbacks
-                generate_prompt_btn.click(
-                    fn=self.generate_prompt_callback,
-                    inputs=[input_image],
-                    outputs=prompt
+                    self.output_image = gr.Image(type="pil", label="Output image")
+
+                prompt_analysis_button.click(
+                    self.process_prompt_analysis,
+                    inputs=[self.input_image_path],
+                    outputs=self.prompt
                 )
-                generate_image_btn.click(
-                    fn=self.generate_image_callback,
-                    inputs=[input_image, prompt, negative_prompt, controlnet_scale],
-                    outputs=output_image
+
+                generate_button.click(
+                    fn=predict,
+                    inputs=[self.input_image_path, self.prompt, self.negative_prompt, self.controlnet_scale],
+                    outputs=self.output_image
                 )
         return demo
 
 img2img = Img2Img()
 img2img.demo.queue()
-img2img.demo.launch(share=True, show_error=True)
+img2img.demo.launch(share=True, show_error=True)
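Note: `share=True` in the final `launch()` only matters when app.py runs outside Spaces, where it asks Gradio to open a public tunnel; on Spaces the platform serves the app itself and Gradio warns that share links are unsupported there. A guarded launch is one way to keep a single code path. The `SPACE_ID` check below is an assumption based on the environment variables Spaces typically injects, not something this commit relies on:

import os

# Hypothetical guard: request a public tunnel only when not running on Spaces,
# where the SPACE_ID environment variable is absent.
on_spaces = os.environ.get("SPACE_ID") is not None
img2img.demo.launch(share=not on_spaces, show_error=True)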