videopix committed
Commit 1703ae8 · verified · 1 Parent(s): 3579187

Update app.py

Files changed (1)
  1. app.py +21 -13
app.py CHANGED
@@ -1,6 +1,7 @@
 import spaces
 import torch
 from fastapi import FastAPI, Request
+from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import HTMLResponse, JSONResponse
 from diffusers import StableDiffusionXLPipeline
 from safetensors.torch import load_file
@@ -14,16 +15,16 @@ from concurrent.futures import ThreadPoolExecutor
 # Model Configuration
 # -----------------------------
 BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
-
-# Multiple LoRA styles supported
 LORA_MODELS = {
     "Ghibli": "./studioghibli_flux_r32-v2.safetensors",
     "GH1bli": "./gh1bli-style.safetensors"
 }
 
 _cached_pipelines = {}
-executor = ThreadPoolExecutor(max_workers=4)  # parallel generation threads
-semaphore = asyncio.Semaphore(4)  # safely limit concurrent runs
+# Thread pool for parallel inference
+executor = ThreadPoolExecutor(max_workers=6)
+# Semaphore controls how many can run at once (to protect memory)
+semaphore = asyncio.Semaphore(6)
 
 # -----------------------------
 # Pipeline Loader
@@ -32,10 +33,10 @@ def load_pipeline(style="Ghibli"):
     if style in _cached_pipelines:
         return _cached_pipelines[style]
 
+    print(f"🔹 Loading SDXL model for style: {style}")
     device = "cpu"
     dtype = torch.float32
 
-    print(f"🔹 Loading SDXL model ({style}) on CPU...")
     pipe = StableDiffusionXLPipeline.from_pretrained(
         BASE_MODEL,
         dtype=dtype,
@@ -47,7 +48,7 @@ def load_pipeline(style="Ghibli"):
 
     model_path = LORA_MODELS.get(style, LORA_MODELS["Ghibli"])
     try:
-        print(f"🎨 Applying LoRA: {model_path}")
+        print(f"🎨 Applying LoRA weights: {model_path}")
         lora_weights = load_file(model_path)
         pipe.unet.load_state_dict(lora_weights, strict=False)
         print("✅ LoRA loaded successfully.")
@@ -59,19 +60,17 @@ def load_pipeline(style="Ghibli"):
 
 
 # -----------------------------
-# Image Generation Function
+# Synchronous Image Generation
 # -----------------------------
 def generate_image_sync(prompt: str, style: str = "Ghibli", seed: int = 42):
     pipe = load_pipeline(style)
     generator = torch.Generator(device="cpu").manual_seed(int(seed))
 
-    # Enriched prompt to ensure full-frame, high-quality output
     enhanced_prompt = (
         f"{prompt}, Studio Ghibli style, full-frame composition, centered subject, "
-        "no borders, clean background, cinematic tone, highly detailed illustration"
+        "clean background, cinematic tone, detailed illustration, digital painting"
    )
 
-    print(f"🖌️ Generating ({style}): {enhanced_prompt}")
     image = pipe(
         prompt=enhanced_prompt,
         height=512,
@@ -81,7 +80,6 @@ def generate_image_sync(prompt: str, style: str = "Ghibli", seed: int = 42):
         generator=generator,
     ).images[0]
 
-    # Crop to remove faint borders if any
     w, h = image.size
     image = image.crop((5, 5, w - 5, h - 5)).resize((512, 512))
     return image
@@ -99,7 +97,16 @@ async def generate_image_async(prompt, style, seed):
 # -----------------------------
 # FastAPI App Setup
 # -----------------------------
-app = FastAPI(title="Studio Ghibli Generator API")
+app = FastAPI(title="Studio Ghibli Generator API", version="2.0")
+
+# Enable cross-app requests (CORS)
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],  # Allow any domain (you can restrict if needed)
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
 
 
 @app.get("/", response_class=HTMLResponse)
@@ -177,6 +184,7 @@ async def api_generate(request: Request):
     buffer = io.BytesIO()
     image.save(buffer, format="PNG")
     img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
+
     return JSONResponse({
         "status": "success",
         "prompt": prompt,
@@ -198,6 +206,6 @@ def keep_alive():
 
 if __name__ == "__main__":
     import uvicorn
-    print("🚀 Launching FastAPI (multi-style, CPU/ZeroGPU mode)")
+    print("🚀 Launching FastAPI (Multi-Request / CPU / ZeroGPU Mode)")
     keep_alive()
     uvicorn.run(app, host="0.0.0.0", port=7860)
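
This commit bumps the shared `executor` and `semaphore` from 4 to 6 but does not show the body of `generate_image_async`, which those globals exist to serve. For reference, a minimal sketch of the usual pattern such a wrapper follows; this is illustrative only and assumes the function simply offloads `generate_image_sync` to the thread pool, which the diff does not confirm:

```python
import asyncio

# Illustrative sketch, not the Space's actual code: run the blocking
# diffusers call in the shared ThreadPoolExecutor while the asyncio
# Semaphore caps how many generations execute at once (6 after this commit).
async def generate_image_async(prompt, style, seed):
    async with semaphore:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            executor, generate_image_sync, prompt, style, seed
        )
```

With `max_workers=6` and a matching `Semaphore(6)`, at most six CPU generations run concurrently; additional requests wait on the semaphore rather than piling onto memory, which is the stated reason for the limit.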