videopix commited on
Commit
a30b23b
·
verified ·
1 Parent(s): ef9be71

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -61
app.py CHANGED
@@ -1,10 +1,10 @@
1
  import spaces
2
  import torch
 
3
  from fastapi import FastAPI, Request
4
  from fastapi.middleware.cors import CORSMiddleware
5
  from fastapi.responses import HTMLResponse, JSONResponse
6
- from diffusers import StableDiffusionXLPipeline
7
- from safetensors.torch import load_file
8
  from PIL import Image
9
  import base64
10
  import io
@@ -12,97 +12,80 @@ import asyncio
12
  from concurrent.futures import ThreadPoolExecutor
13
 
14
  # -----------------------------
15
- # Model Configuration
16
  # -----------------------------
17
- BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
18
- LORA_MODELS = {
19
- "Ghibli": "./studioghibli_flux_r32-v2.safetensors",
20
- "GH1bli": "./gh1bli-style.safetensors"
21
- }
 
22
 
23
  _cached_pipelines = {}
24
- # Thread pool for parallel inference
25
- executor = ThreadPoolExecutor(max_workers=6)
26
- # Semaphore controls how many can run at once (to protect memory)
27
- semaphore = asyncio.Semaphore(6)
28
 
29
  # -----------------------------
30
- # Pipeline Loader
31
  # -----------------------------
32
- def load_pipeline(style="Ghibli"):
33
- if style in _cached_pipelines:
34
- return _cached_pipelines[style]
35
 
36
- print(f"🔹 Loading SDXL model for style: {style}")
37
- device = "cpu"
38
- dtype = torch.float32
39
 
40
- pipe = StableDiffusionXLPipeline.from_pretrained(
41
  BASE_MODEL,
42
- dtype=dtype,
43
- use_safetensors=True,
44
- ).to(device)
45
 
46
  pipe.enable_attention_slicing()
47
  pipe.enable_vae_tiling()
48
 
49
- model_path = LORA_MODELS.get(style, LORA_MODELS["Ghibli"])
50
- try:
51
- print(f"🎨 Applying LoRA weights: {model_path}")
52
- lora_weights = load_file(model_path)
53
- pipe.unet.load_state_dict(lora_weights, strict=False)
54
- print("✅ LoRA loaded successfully.")
55
- except Exception as e:
56
- print(f"⚠️ LoRA load failed: {e}")
57
-
58
- _cached_pipelines[style] = pipe
59
  return pipe
60
 
61
 
62
  # -----------------------------
63
- # Synchronous Image Generation
64
  # -----------------------------
65
- def generate_image_sync(prompt: str, style: str = "Ghibli", seed: int = 42):
66
- pipe = load_pipeline(style)
67
  generator = torch.Generator(device="cpu").manual_seed(int(seed))
68
 
69
- enhanced_prompt = (
70
- f"{prompt}, Studio Ghibli style, full-frame composition, centered subject, "
71
- "clean background, cinematic tone, detailed illustration, digital painting"
72
- )
73
 
74
  image = pipe(
75
- prompt=enhanced_prompt,
76
- height=512,
77
- width=512,
78
- num_inference_steps=30,
79
- guidance_scale=7.0,
80
  generator=generator,
81
  ).images[0]
82
 
83
- w, h = image.size
84
- image = image.crop((5, 5, w - 5, h - 5)).resize((512, 512))
85
  return image
86
 
87
 
88
  # -----------------------------
89
- # Async Wrapper for Concurrency
90
  # -----------------------------
91
  async def generate_image_async(prompt, style, seed):
92
  async with semaphore:
93
- loop = asyncio.get_event_loop()
94
  return await loop.run_in_executor(executor, generate_image_sync, prompt, style, seed)
95
 
96
 
97
  # -----------------------------
98
  # FastAPI App Setup
99
  # -----------------------------
100
- app = FastAPI(title="Studio Ghibli Generator API", version="2.0")
101
 
102
- # Enable cross-app requests (CORS)
103
  app.add_middleware(
104
  CORSMiddleware,
105
- allow_origins=["*"], # Allow any domain (you can restrict if needed)
106
  allow_credentials=True,
107
  allow_methods=["*"],
108
  allow_headers=["*"],
@@ -114,22 +97,21 @@ def home():
114
  return """
115
  <html>
116
  <head>
117
- <title>Studio Ghibli Generator</title>
118
  <style>
119
  body { font-family: Arial; text-align: center; padding: 2rem; background-color: #f9f9f9; }
120
  input, select, button { margin: 0.5rem; padding: 0.6rem; width: 300px; border-radius: 6px; border: 1px solid #ccc; }
121
  button { background-color: #444; color: white; cursor: pointer; }
122
  button:hover { background-color: #333; }
123
- img { margin-top: 1rem; border-radius: 12px; max-width: 512px; }
124
  </style>
125
  </head>
126
  <body>
127
- <h2>🎨 Studio Ghibli Generator</h2>
128
  <form id="generateForm">
129
- <input id="prompt" placeholder="e.g. a boy riding a dragon" required><br>
130
  <select id="style">
131
- <option value="Ghibli">Ghibli</option>
132
- <option value="GH1bli">GH1bli</option>
133
  </select><br>
134
  <input id="seed" type="number" value="42"><br>
135
  <button type="submit">Generate Image</button>
@@ -172,7 +154,7 @@ async def api_generate(request: Request):
172
  try:
173
  data = await request.json()
174
  prompt = data.get("prompt", "").strip()
175
- style = data.get("style", "Ghibli")
176
  seed = data.get("seed", 42)
177
  if not prompt:
178
  return JSONResponse({"status": "error", "message": "Prompt required"}, status_code=400)
@@ -197,7 +179,7 @@ async def api_generate(request: Request):
197
 
198
 
199
  # -----------------------------
200
- # ZeroGPU / Keep Alive Hook
201
  # -----------------------------
202
  @spaces.GPU
203
  def keep_alive():
@@ -206,6 +188,6 @@ def keep_alive():
206
 
207
  if __name__ == "__main__":
208
  import uvicorn
209
- print("🚀 Launching FastAPI (Multi-Request / CPU / ZeroGPU Mode)")
210
  keep_alive()
211
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  import spaces
2
  import torch
3
+ import os
4
  from fastapi import FastAPI, Request
5
  from fastapi.middleware.cors import CORSMiddleware
6
  from fastapi.responses import HTMLResponse, JSONResponse
7
+ from diffusers import FluxPipeline
 
8
  from PIL import Image
9
  import base64
10
  import io
 
12
  from concurrent.futures import ThreadPoolExecutor
13
 
14
  # -----------------------------
15
+ # Hugging Face Token Support
16
  # -----------------------------
17
+ HF_TOKEN = os.getenv("HF_TOKEN") # Must be set on server / spaces
18
+
19
+ # -----------------------------
20
+ # Model (FLUX.1-schnell)
21
+ # -----------------------------
22
+ BASE_MODEL = "black-forest-labs/FLUX.1-schnell"
23
 
24
  _cached_pipelines = {}
25
+ executor = ThreadPoolExecutor(max_workers=4)
26
+ semaphore = asyncio.Semaphore(4)
 
 
27
 
28
  # -----------------------------
29
+ # Load FLUX Pipeline
30
  # -----------------------------
31
+ def load_pipeline():
32
+ if "flux" in _cached_pipelines:
33
+ return _cached_pipelines["flux"]
34
 
35
+ print("🔹 Loading FLUX.1-schnell Model")
 
 
36
 
37
+ pipe = FluxPipeline.from_pretrained(
38
  BASE_MODEL,
39
+ torch_dtype=torch.float32,
40
+ use_auth_token=HF_TOKEN, # <--- token applied here
41
+ ).to("cpu")
42
 
43
  pipe.enable_attention_slicing()
44
  pipe.enable_vae_tiling()
45
 
46
+ _cached_pipelines["flux"] = pipe
 
 
 
 
 
 
 
 
 
47
  return pipe
48
 
49
 
50
  # -----------------------------
51
+ # Image Generation
52
  # -----------------------------
53
+ def generate_image_sync(prompt: str, style: str = None, seed: int = 42):
54
+ pipe = load_pipeline()
55
  generator = torch.Generator(device="cpu").manual_seed(int(seed))
56
 
57
+ width = 1920
58
+ height = 1080
 
 
59
 
60
  image = pipe(
61
+ prompt=prompt,
62
+ width=width,
63
+ height=height,
64
+ num_inference_steps=20,
65
+ guidance_scale=3.5,
66
  generator=generator,
67
  ).images[0]
68
 
 
 
69
  return image
70
 
71
 
72
  # -----------------------------
73
+ # Async Wrapper
74
  # -----------------------------
75
  async def generate_image_async(prompt, style, seed):
76
  async with semaphore:
77
+ loop = asyncio.get_running_loop()
78
  return await loop.run_in_executor(executor, generate_image_sync, prompt, style, seed)
79
 
80
 
81
  # -----------------------------
82
  # FastAPI App Setup
83
  # -----------------------------
84
+ app = FastAPI(title="FLUX Image Generator API", version="2.1")
85
 
 
86
  app.add_middleware(
87
  CORSMiddleware,
88
+ allow_origins=["*"],
89
  allow_credentials=True,
90
  allow_methods=["*"],
91
  allow_headers=["*"],
 
97
  return """
98
  <html>
99
  <head>
100
+ <title>FLUX Generator</title>
101
  <style>
102
  body { font-family: Arial; text-align: center; padding: 2rem; background-color: #f9f9f9; }
103
  input, select, button { margin: 0.5rem; padding: 0.6rem; width: 300px; border-radius: 6px; border: 1px solid #ccc; }
104
  button { background-color: #444; color: white; cursor: pointer; }
105
  button:hover { background-color: #333; }
106
+ img { margin-top: 1rem; border-radius: 12px; max-width: 90%; }
107
  </style>
108
  </head>
109
  <body>
110
+ <h2>🎨 FLUX.1-schnell Image Generator</h2>
111
  <form id="generateForm">
112
+ <input id="prompt" placeholder="Describe your image..." required><br>
113
  <select id="style">
114
+ <option value="default">Default (FLUX)</option>
 
115
  </select><br>
116
  <input id="seed" type="number" value="42"><br>
117
  <button type="submit">Generate Image</button>
 
154
  try:
155
  data = await request.json()
156
  prompt = data.get("prompt", "").strip()
157
+ style = data.get("style", "default")
158
  seed = data.get("seed", 42)
159
  if not prompt:
160
  return JSONResponse({"status": "error", "message": "Prompt required"}, status_code=400)
 
179
 
180
 
181
  # -----------------------------
182
+ # ZeroGPU Keep Alive
183
  # -----------------------------
184
  @spaces.GPU
185
  def keep_alive():
 
188
 
189
  if __name__ == "__main__":
190
  import uvicorn
191
+ print("🚀 Launching FastAPI with FLUX + HF Token")
192
  keep_alive()
193
  uvicorn.run(app, host="0.0.0.0", port=7860)