videopix committed on
Commit 4cb20ce · verified · 1 Parent(s): 3a9efeb

Update app.py

Files changed (1)
  1. app.py +58 -138
app.py CHANGED
@@ -1,89 +1,55 @@
- import spaces
- import torch
- import os
  from fastapi import FastAPI, Request
  from fastapi.middleware.cors import CORSMiddleware
  from fastapi.responses import HTMLResponse, JSONResponse
  from diffusers import FluxPipeline
  from PIL import Image
- import base64
- import io
- import asyncio
  from concurrent.futures import ThreadPoolExecutor

- # -----------------------------
- # Hugging Face Token Support
- # -----------------------------
- HF_TOKEN = os.getenv("HF_TOKEN") # Must be set on server / spaces
-
- # -----------------------------
- # Model (FLUX.1-schnell)
- # -----------------------------
  BASE_MODEL = "black-forest-labs/FLUX.1-schnell"

- _cached_pipelines = {}
- executor = ThreadPoolExecutor(max_workers=2) # reduce thread pressure for CPU
- semaphore = asyncio.Semaphore(1) # avoids CPU overload

- # -----------------------------
- # Load FLUX Pipeline (Optimized)
- # -----------------------------
  def load_pipeline():
-     if "flux" in _cached_pipelines:
-         return _cached_pipelines["flux"]
-
-     print("🔹 Loading FLUX.1-schnell Model (Optimized for CPU)")
-
      pipe = FluxPipeline.from_pretrained(
          BASE_MODEL,
-         torch_dtype=torch.float16, # <= Faster
          use_auth_token=HF_TOKEN,
-     )
-
-     pipe.to("cpu", dtype=torch.float16) # <= Ensure CPU uses FP16
      pipe.enable_attention_slicing()
      pipe.enable_vae_tiling()
-
-     _cached_pipelines["flux"] = pipe
      return pipe

-
- # -----------------------------
- # Image Generation (Optimized)
- # -----------------------------
- def generate_image_sync(prompt: str, style: str = None, seed: int = 42):
      pipe = load_pipeline()
-     generator = torch.Generator(device="cpu").manual_seed(int(seed))
-
-     width = 1024 # <= Faster resolution
-     height = 576
-
      image = pipe(
          prompt=prompt,
-         width=width,
-         height=height,
-         num_inference_steps=8, # <= Lower steps for speed
-         guidance_scale=2.8, # <= Balanced for FLUX
-         generator=generator,
      ).images[0]

-     return image
-
-
- # -----------------------------
- # Async Wrapper
- # -----------------------------
- async def generate_image_async(prompt, style, seed):
      async with semaphore:
          loop = asyncio.get_running_loop()
-         return await loop.run_in_executor(executor, generate_image_sync, prompt, style, seed)
-
-
- # -----------------------------
- # FastAPI App Setup
- # -----------------------------
- app = FastAPI(title="FLUX Image Generator API", version="2.2")

  app.add_middleware(
      CORSMiddleware,
      allow_origins=["*"],
@@ -92,103 +58,57 @@ app.add_middleware(
      allow_headers=["*"],
  )

-
  @app.get("/", response_class=HTMLResponse)
  def home():
      return """
-     <html>
-     <head>
-     <title>FLUX Generator</title>
-     <style>
-     body { font-family: Arial; text-align: center; padding: 2rem; background-color: #f9f9f9; }
-     input, select, button { margin: 0.5rem; padding: 0.6rem; width: 300px; border-radius: 6px; border: 1px solid #ccc; }
-     button { background-color: #444; color: white; cursor: pointer; }
-     button:hover { background-color: #333; }
-     img { margin-top: 1rem; border-radius: 12px; max-width: 90%; }
-     </style>
-     </head>
-     <body>
-     <h2>🎨 FLUX.1-schnell Image Generator</h2>
-     <form id="generateForm">
-     <input id="prompt" placeholder="Describe your image..." required><br>
-     <select id="style">
-     <option value="default">Default (FLUX)</option>
-     </select><br>
-     <input id="seed" type="number" value="42"><br>
-     <button type="submit">Generate Image</button>
-     </form>
-     <div id="result"></div>
-     <script>
-     const form = document.getElementById("generateForm");
-     const resultDiv = document.getElementById("result");
-     form.addEventListener("submit", async (e) => {
-         e.preventDefault();
-         resultDiv.innerHTML = "<p>⏳ Generating image...</p>";
-         const data = {
-             prompt: document.getElementById("prompt").value,
-             style: document.getElementById("style").value,
-             seed: parseInt(document.getElementById("seed").value)
-         };
-         const res = await fetch("/api/generate", {
-             method: "POST",
-             headers: { "Content-Type": "application/json" },
-             body: JSON.stringify(data)
-         });
-         const json = await res.json();
-         if (json.status === "success") {
-             resultDiv.innerHTML = `<img src="data:image/png;base64,${json.image_base64}"/><p>✅ Done!</p>`;
-         } else {
-             resultDiv.innerHTML = `<p style='color:red'>❌ ${json.message}</p>`;
-         }
-     });
-     </script>
-     </body>
-     </html>
      """

-
- # -----------------------------
- # API Endpoint
- # -----------------------------
  @app.post("/api/generate")
  async def api_generate(request: Request):
      try:
          data = await request.json()
-         prompt = data.get("prompt", "").strip()
-         style = data.get("style", "default")
-         seed = data.get("seed", 42)
          if not prompt:
-             return JSONResponse({"status": "error", "message": "Prompt required"}, status_code=400)
      except Exception:
-         return JSONResponse({"status": "error", "message": "Invalid JSON"}, status_code=400)

      try:
-         image = await generate_image_async(prompt, style, seed)
-         buffer = io.BytesIO()
-         image.save(buffer, format="PNG")
-         img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
-
-         return JSONResponse({
-             "status": "success",
-             "prompt": prompt,
-             "style": style,
-             "image_base64": img_base64
-         })
      except Exception as e:
          print(f"❌ Error: {e}")
-         return JSONResponse({"status": "error", "message": str(e)}, status_code=500)

-
- # -----------------------------
- # ZeroGPU Keep Alive
- # -----------------------------
  @spaces.GPU
- def keep_alive():
-     return "ZeroGPU Ready"
-

  if __name__ == "__main__":
      import uvicorn
-     print("🚀 Launching FastAPI (FLUX optimized for CPU)")
      keep_alive()
      uvicorn.run(app, host="0.0.0.0", port=7860)
+ import os, io, base64, asyncio, torch, spaces
  from fastapi import FastAPI, Request
  from fastapi.middleware.cors import CORSMiddleware
  from fastapi.responses import HTMLResponse, JSONResponse
  from diffusers import FluxPipeline
  from PIL import Image
  from concurrent.futures import ThreadPoolExecutor

+ HF_TOKEN = os.getenv("HF_TOKEN")
  BASE_MODEL = "black-forest-labs/FLUX.1-schnell"

+ _cached = {}
+ # moderate concurrency so CPU doesn’t choke
+ executor = ThreadPoolExecutor(max_workers=3)
+ semaphore = asyncio.Semaphore(3)

  def load_pipeline():
+     if "flux" in _cached:
+         return _cached["flux"]
+     print("🔹 Loading FLUX.1-schnell (fast mode)")
      pipe = FluxPipeline.from_pretrained(
          BASE_MODEL,
+         torch_dtype=torch.float16,
          use_auth_token=HF_TOKEN,
+     ).to("cpu", dtype=torch.float16)
      pipe.enable_attention_slicing()
      pipe.enable_vae_tiling()
+     _cached["flux"] = pipe
      return pipe

+ def generate_image_sync(prompt: str, seed: int = 42):
      pipe = load_pipeline()
+     gen = torch.Generator(device="cpu").manual_seed(int(seed))
+     # smaller size and steps for speed
+     w, h = 768, 432
      image = pipe(
          prompt=prompt,
+         width=w,
+         height=h,
+         num_inference_steps=6,
+         guidance_scale=2.5,
+         generator=gen,
      ).images[0]
+     # slight upscale back to 960×540 to keep output clear
+     return image.resize((960, 540), Image.BICUBIC)

+ async def generate_image_async(prompt, seed):
      async with semaphore:
          loop = asyncio.get_running_loop()
+         return await loop.run_in_executor(executor, generate_image_sync, prompt, seed)

+ app = FastAPI(title="FLUX Fast API", version="3.1")
  app.add_middleware(
      CORSMiddleware,
      allow_origins=["*"],

      allow_headers=["*"],
  )

  @app.get("/", response_class=HTMLResponse)
  def home():
      return """
+     <html><head><title>FLUX Fast</title>
+     <style>body{font-family:Arial;text-align:center;padding:2rem}
+     input,button{margin:.5rem;padding:.6rem;width:300px;border-radius:6px;border:1px solid #ccc}
+     button{background:#444;color:#fff}button:hover{background:#333}
+     img{margin-top:1rem;max-width:90%;border-radius:12px}</style></head>
+     <body><h2>🎨 FLUX Fast Generator</h2>
+     <form id='f'><input id='prompt' placeholder='Describe image...' required><br>
+     <input id='seed' type='number' value='42'><br>
+     <button>Generate</button></form><div id='out'></div>
+     <script>
+     const f=document.getElementById('f'),o=document.getElementById('out');
+     f.addEventListener('submit',async e=>{
+         e.preventDefault();o.innerHTML='⏳ Generating...';
+         const res=await fetch('/api/generate',{method:'POST',headers:{'Content-Type':'application/json'},
+             body:JSON.stringify({prompt:prompt.value,seed:+seed.value})});
+         const j=await res.json();
+         if(j.status==='success')o.innerHTML=`<img src="data:image/png;base64,${j.image_base64}"/><p>✅ Done!</p>`;
+         else o.innerHTML=`<p style='color:red'>❌ ${j.message}</p>`;
+     });
+     </script></body></html>
      """

  @app.post("/api/generate")
  async def api_generate(request: Request):
      try:
          data = await request.json()
+         prompt = str(data.get("prompt", "")).strip()
+         seed = int(data.get("seed", 42))
          if not prompt:
+             return JSONResponse({"status": "error", "message": "Prompt required"}, 400)
      except Exception:
+         return JSONResponse({"status": "error", "message": "Invalid JSON"}, 400)

      try:
+         image = await generate_image_async(prompt, seed)
+         buf = io.BytesIO()
+         image.save(buf, format="PNG")
+         img64 = base64.b64encode(buf.getvalue()).decode("utf-8")
+         return JSONResponse({"status": "success", "prompt": prompt, "image_base64": img64})
      except Exception as e:
          print(f"❌ Error: {e}")
+         return JSONResponse({"status": "error", "message": str(e)}, 500)

  @spaces.GPU
+ def keep_alive(): return "ZeroGPU Ready"

  if __name__ == "__main__":
      import uvicorn
+     print("🚀 Launching Fast FLUX API")
      keep_alive()
      uvicorn.run(app, host="0.0.0.0", port=7860)
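
For reference, a minimal client sketch against the updated /api/generate endpoint. The base URL, prompt text, and output filename below are placeholders; the prompt/seed request fields and the status/image_base64 response fields mirror what api_generate reads and returns in the diff above.

# Minimal client sketch; BASE_URL is an assumed placeholder, not part of the commit.
import base64
import requests

BASE_URL = "http://localhost:7860"  # placeholder; replace with the deployed Space URL

resp = requests.post(
    f"{BASE_URL}/api/generate",
    json={"prompt": "a lighthouse at dusk", "seed": 42},  # fields read by api_generate
    timeout=600,  # CPU generation can take a while
)
payload = resp.json()

if resp.ok and payload.get("status") == "success":
    # The endpoint returns a base64-encoded PNG; decode and save it locally.
    with open("output.png", "wb") as f:
        f.write(base64.b64decode(payload["image_base64"]))
else:
    print("Generation failed:", payload.get("message"))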