videopix committed on
Commit
2cc5994
·
verified ·
1 Parent(s): 07d8ca5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +209 -79
app.py CHANGED
@@ -1,55 +1,87 @@
1
- import os, io, base64, asyncio, torch, spaces
 
 
 
 
2
  from fastapi import FastAPI, Request
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from fastapi.responses import HTMLResponse, JSONResponse
5
- from diffusers import FluxPipeline
6
- from PIL import Image
7
  from concurrent.futures import ThreadPoolExecutor
 
8
 
9
  HF_TOKEN = os.getenv("HF_TOKEN")
10
  BASE_MODEL = "black-forest-labs/FLUX.1-schnell"
11
 
12
- _cached = {}
13
- # moderate concurrency so CPU doesn’t choke
14
  executor = ThreadPoolExecutor(max_workers=3)
15
  semaphore = asyncio.Semaphore(3)
16
 
17
- def load_pipeline():
18
- if "flux" in _cached:
19
- return _cached["flux"]
20
- print("🔹 Loading FLUX.1-schnell (fast mode)")
 
 
 
 
 
 
 
 
 
 
 
21
  pipe = FluxPipeline.from_pretrained(
22
  BASE_MODEL,
23
- torch_dtype=torch.float16,
24
  use_auth_token=HF_TOKEN,
25
- ).to("cpu", dtype=torch.float16)
26
- pipe.enable_attention_slicing()
27
- pipe.enable_vae_tiling()
28
- _cached["flux"] = pipe
29
- return pipe
30
-
31
- def generate_image_sync(prompt: str, seed: int = 42):
32
- pipe = load_pipeline()
33
- gen = torch.Generator(device="cpu").manual_seed(int(seed))
34
- # smaller size and steps for speed
35
- w, h = 768, 432
36
- image = pipe(
 
37
  prompt=prompt,
38
- width=w,
39
- height=h,
40
- num_inference_steps=4,
41
- guidance_scale=3,
42
- generator=gen,
43
  ).images[0]
44
- # slight upscale back to 960×540 to keep output clear
45
- return image.resize((960, 540), Image.BICUBIC)
46
 
 
 
 
 
 
 
 
 
 
 
47
  async def generate_image_async(prompt, seed):
48
  async with semaphore:
49
  loop = asyncio.get_running_loop()
50
- return await loop.run_in_executor(executor, generate_image_sync, prompt, seed)
 
 
 
 
 
51
 
 
 
 
 
52
  app = FastAPI(title="FLUX Fast API", version="3.1")
 
53
  app.add_middleware(
54
  CORSMiddleware,
55
  allow_origins=["*"],
@@ -58,78 +90,176 @@ app.add_middleware(
58
  allow_headers=["*"],
59
  )
60
 
 
61
  @app.get("/", response_class=HTMLResponse)
62
  def home():
63
  return """
64
- <html><head><title>FLUX Fast</title>
65
- <style>body{font-family:Arial;text-align:center;padding:2rem}
66
- input,button{margin:.5rem;padding:.6rem;width:300px;border-radius:6px;border:1px solid #ccc}
67
- button{background:#444;color:#fff}button:hover{background:#333}
68
- img{margin-top:1rem;max-width:90%;border-radius:12px}</style></head>
69
- <body><h2>🎨 FLUX Fast Generator</h2>
70
- <form id='f'><input id='prompt' placeholder='Describe image...' required><br>
71
- <input id='seed' type='number' value='42'><br>
72
- <button>Generate</button></form><div id='out'></div>
73
- <script>
74
- const form = document.getElementById("f");
75
- const promptInput = document.getElementById("prompt");
76
- const seedInput = document.getElementById("seed");
77
- const resultDiv = document.getElementById("out");
78
-
79
- form.addEventListener("submit", async (e) => {
80
- e.preventDefault();
81
- const prompt = promptInput.value.trim();
82
- if (!prompt) {
83
- resultDiv.innerHTML = "<p style='color:red'>❌ Please enter a prompt</p>";
84
- return;
85
- }
86
- resultDiv.innerHTML = "<p>⏳ Generating...</p>";
87
- const payload = {
88
- prompt: prompt,
89
- seed: parseInt(seedInput.value || 42)
90
- };
91
- const res = await fetch("/api/generate", {
92
- method: "POST",
93
- headers: { "Content-Type": "application/json" },
94
- body: JSON.stringify(payload)
95
- });
96
- const json = await res.json();
97
- if (json.status === "success") {
98
- resultDiv.innerHTML = `<img src="data:image/png;base64,${json.image_base64}"/><p>✅ Done!</p>`;
99
- } else {
100
- resultDiv.innerHTML = `<p style='color:red'>❌ ${json.message}</p>`;
101
- }
102
- });
103
- </script>
104
- </body></html>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  """
106
 
 
107
  @app.post("/api/generate")
108
  async def api_generate(request: Request):
109
  try:
110
  data = await request.json()
111
  prompt = str(data.get("prompt", "")).strip()
112
  seed = int(data.get("seed", 42))
 
113
  if not prompt:
114
  return JSONResponse({"status": "error", "message": "Prompt required"}, 400)
 
115
  except Exception:
116
  return JSONResponse({"status": "error", "message": "Invalid JSON"}, 400)
117
 
118
  try:
119
- image = await generate_image_async(prompt, seed)
120
- buf = io.BytesIO()
121
- image.save(buf, format="PNG")
122
- img64 = base64.b64encode(buf.getvalue()).decode("utf-8")
123
- return JSONResponse({"status": "success", "prompt": prompt, "image_base64": img64})
124
  except Exception as e:
125
- print(f"❌ Error: {e}")
126
  return JSONResponse({"status": "error", "message": str(e)}, 500)
127
 
 
128
  @spaces.GPU
129
- def keep_alive(): return "ZeroGPU Ready"
 
 
130
 
131
  if __name__ == "__main__":
132
  import uvicorn
133
  print("🚀 Launching Fast FLUX API")
134
  keep_alive()
135
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ import os
2
+ import io
3
+ import base64
4
+ import asyncio
5
+ import spaces
6
  from fastapi import FastAPI, Request
7
  from fastapi.middleware.cors import CORSMiddleware
8
  from fastapi.responses import HTMLResponse, JSONResponse
 
 
9
  from concurrent.futures import ThreadPoolExecutor
10
+ from PIL import Image
11
 
12
  HF_TOKEN = os.getenv("HF_TOKEN")
13
  BASE_MODEL = "black-forest-labs/FLUX.1-schnell"
14
 
15
+ # concurrency
 
16
  executor = ThreadPoolExecutor(max_workers=3)
17
  semaphore = asyncio.Semaphore(3)
18
 
19
+ # --------------------------------------------------------
20
+ # IMPORTANT: no torch.cuda calls, no GPU detection, no
21
+ # pipeline loading here. Only CPU-safe imports.
22
+ # --------------------------------------------------------
23
+ from diffusers import FluxPipeline
24
+ import torch
25
+
26
+ # --------------------------------------------------------
27
+ # GPU function: runs in a separate GPU worker process.
28
+ # Full model load + inference must live here.
29
+ # --------------------------------------------------------
30
+ @spaces.GPU
31
+ def gpu_generate(prompt: str, seed: int):
32
+ print("⚡ ZeroGPU worker starting model load + inference")
33
+
34
  pipe = FluxPipeline.from_pretrained(
35
  BASE_MODEL,
36
+ torch_dtype=torch.float16, # safe on GPU worker
37
  use_auth_token=HF_TOKEN,
38
+ low_cpu_mem_usage=True
39
+ ).to("cuda")
40
+
41
+ try:
42
+ pipe.enable_attention_slicing()
43
+ pipe.enable_vae_tiling()
44
+ pipe.enable_xformers_memory_efficient_attention()
45
+ except Exception:
46
+ pass
47
+
48
+ generator = torch.Generator(device="cuda").manual_seed(seed)
49
+
50
+ img = pipe(
51
  prompt=prompt,
52
+ width=768,
53
+ height=432,
54
+ num_inference_steps=6,
55
+ guidance_scale=2.5,
56
+ generator=generator,
57
  ).images[0]
 
 
58
 
59
+ img = img.resize((960, 540), Image.BICUBIC)
60
+
61
+ buf = io.BytesIO()
62
+ img.save(buf, format="PNG")
63
+ return base64.b64encode(buf.getvalue()).decode("utf-8")
64
+
65
+
66
+ # --------------------------------------------------------
67
+ # Async wrapper to allow multiple simultaneous requests
68
+ # --------------------------------------------------------
69
  async def generate_image_async(prompt, seed):
70
  async with semaphore:
71
  loop = asyncio.get_running_loop()
72
+ return await loop.run_in_executor(
73
+ executor,
74
+ gpu_generate,
75
+ prompt,
76
+ seed
77
+ )
78
 
79
+
80
+ # --------------------------------------------------------
81
+ # FastAPI app
82
+ # --------------------------------------------------------
83
  app = FastAPI(title="FLUX Fast API", version="3.1")
84
+
85
  app.add_middleware(
86
  CORSMiddleware,
87
  allow_origins=["*"],
 
90
  allow_headers=["*"],
91
  )
92
 
93
+
94
  @app.get("/", response_class=HTMLResponse)
95
  def home():
96
  return """
97
+ <!doctype html>
98
+ <html lang="en">
99
+ <head>
100
+ <meta charset="utf-8" />
101
+ <meta name="viewport" content="width=device-width,initial-scale=1" />
102
+ <title>FLUX Fast Generator</title>
103
+ <style>
104
+ :root{font-family:Inter, Roboto, Arial, sans-serif; color:#111}
105
+ body{max-width:900px;margin:32px auto;padding:24px;line-height:1.45}
106
+ h1{font-size:1.6rem;margin:0 0 12px}
107
+ p.lead{color:#444;margin:0 0 18px}
108
+ .card{border:1px solid #e6e6e6;border-radius:12px;padding:18px;box-shadow:0 4px 14px rgba(20,20,20,0.03)}
109
+ label{display:block;margin:12px 0 6px;font-weight:600}
110
+ input[type="text"], input[type="number"], textarea{
111
+ width:100%;box-sizing:border-box;padding:10px;border-radius:8px;border:1px solid #d5d5d5;font-size:14px
112
+ }
113
+ textarea{min-height:100px;resize:vertical}
114
+ .row{display:flex;gap:12px;align-items:center;margin-top:12px}
115
+ button{padding:10px 16px;border-radius:8px;border:0;background:#111;color:#fff;cursor:pointer}
116
+ button.secondary{background:#f3f3f3;color:#111;border:1px solid #ddd}
117
+ button:disabled{opacity:0.6;cursor:not-allowed}
118
+ .meta{font-size:13px;color:#666;margin-top:8px}
119
+ .result{margin-top:18px;text-align:center}
120
+ .result img{max-width:100%;border-radius:12px;box-shadow:0 6px 30px rgba(0,0,0,0.06)}
121
+ .footer{margin-top:18px;font-size:13px;color:#666;text-align:center}
122
+ .progress{display:inline-flex;align-items:center;gap:10px}
123
+ .spinner{
124
+ width:18px;height:18px;border-radius:50%;border:3px solid rgba(0,0,0,0.08);border-top-color:#111;animation:spin 1s linear infinite
125
+ }
126
+ @keyframes spin{to{transform:rotate(360deg)}}
127
+ .download{display:inline-block;margin-top:8px;padding:8px 12px;border-radius:8px;background:#fff;border:1px solid #ddd;color:#111;text-decoration:none}
128
+ </style>
129
+ </head>
130
+ <body>
131
+ <h1>FLUX Fast Generator</h1>
132
+ <p class="lead">Enter a prompt and press Generate. The backend runs model inference and returns the generated image.</p>
133
+ <div class="card">
134
+ <form id="genForm">
135
+ <label for="prompt">Prompt</label>
136
+ <textarea id="prompt" placeholder="A scene of a futuristic city at golden hour, cinematic lighting, ultra-detailed..." required></textarea>
137
+ <div style="display:flex;gap:12px;flex-wrap:wrap;margin-top:8px;">
138
+ <div style="flex:1;min-width:160px">
139
+ <label for="seed">Seed (optional)</label>
140
+ <input id="seed" type="number" value="42" />
141
+ </div>
142
+ <div style="width:160px">
143
+ <label for="steps">Steps</label>
144
+ <input id="steps" type="number" value="6" min="1" max="50" />
145
+ </div>
146
+ <div style="width:160px">
147
+ <label for="scale">Guidance</label>
148
+ <input id="scale" type="number" step="0.1" value="2.5" min="1" max="20" />
149
+ </div>
150
+ </div>
151
+ <div class="row" style="margin-top:18px">
152
+ <button id="genBtn" type="submit">Generate</button>
153
+ <button id="clearBtn" type="button" class="secondary">Clear</button>
154
+ <div class="meta" id="status" style="margin-left:auto"></div>
155
+ </div>
156
+ </form>
157
+ <div class="result" id="resultArea" aria-live="polite"></div>
158
+ </div>
159
+ <div class="footer">Tip: keep steps and resolution low for faster results in CPU or cold GPU environments.</div>
160
+ <script>
161
+ const form = document.getElementById('genForm');
162
+ const promptInput = document.getElementById('prompt');
163
+ const seedInput = document.getElementById('seed');
164
+ const stepsInput = document.getElementById('steps');
165
+ const scaleInput = document.getElementById('scale');
166
+ const genBtn = document.getElementById('genBtn');
167
+ const clearBtn = document.getElementById('clearBtn');
168
+ const status = document.getElementById('status');
169
+ const resultArea = document.getElementById('resultArea');
170
+ clearBtn.addEventListener('click', () => {
171
+ promptInput.value = '';
172
+ resultArea.innerHTML = '';
173
+ status.textContent = '';
174
+ });
175
+ form.addEventListener('submit', async (e) => {
176
+ e.preventDefault();
177
+ const prompt = promptInput.value.trim();
178
+ if (!prompt) {
179
+ status.textContent = 'Please enter a prompt';
180
+ return;
181
+ }
182
+ const payload = {
183
+ prompt: prompt,
184
+ seed: parseInt(seedInput.value || 42),
185
+ num_inference_steps: parseInt(stepsInput.value || 6),
186
+ guidance_scale: parseFloat(scaleInput.value || 2.5)
187
+ };
188
+ // UI state
189
+ genBtn.disabled = true;
190
+ clearBtn.disabled = true;
191
+ status.innerHTML = '<span class="progress"><span class="spinner"></span> Generating...</span>';
192
+ resultArea.innerHTML = '';
193
+ const start = Date.now();
194
+ try {
195
+ const res = await fetch('/api/generate', {
196
+ method: 'POST',
197
+ headers: {'Content-Type': 'application/json'},
198
+ body: JSON.stringify(payload)
199
+ });
200
+ const json = await res.json();
201
+ if (!res.ok || json.status !== 'success') {
202
+ const msg = json && json.message ? json.message : 'Generation failed';
203
+ status.textContent = 'Error: ' + msg;
204
+ genBtn.disabled = false;
205
+ clearBtn.disabled = false;
206
+ return;
207
+ }
208
+ const took = ((Date.now() - start) / 1000).toFixed(1);
209
+ status.textContent = `Done in ${took}s`;
210
+ const imgData = 'data:image/png;base64,' + json.image_base64;
211
+ const img = document.createElement('img');
212
+ img.src = imgData;
213
+ img.alt = prompt;
214
+ resultArea.appendChild(img);
215
+ const dl = document.createElement('a');
216
+ dl.href = imgData;
217
+ dl.download = 'flux_gen.png';
218
+ dl.className = 'download';
219
+ dl.textContent = 'Download PNG';
220
+ resultArea.appendChild(dl);
221
+ } catch (err) {
222
+ console.error(err);
223
+ status.textContent = 'Network or server error';
224
+ } finally {
225
+ genBtn.disabled = false;
226
+ clearBtn.disabled = false;
227
+ }
228
+ });
229
+ </script>
230
+ </body>
231
+ </html>
232
  """
233
 
234
+
235
  @app.post("/api/generate")
236
  async def api_generate(request: Request):
237
  try:
238
  data = await request.json()
239
  prompt = str(data.get("prompt", "")).strip()
240
  seed = int(data.get("seed", 42))
241
+
242
  if not prompt:
243
  return JSONResponse({"status": "error", "message": "Prompt required"}, 400)
244
+
245
  except Exception:
246
  return JSONResponse({"status": "error", "message": "Invalid JSON"}, 400)
247
 
248
  try:
249
+ img64 = await generate_image_async(prompt, seed)
250
+ return JSONResponse({"status": "success", "image_base64": img64, "prompt": prompt})
 
 
 
251
  except Exception as e:
252
+ print("❌ Error:", e)
253
  return JSONResponse({"status": "error", "message": str(e)}, 500)
254
 
255
+
256
  @spaces.GPU
257
+ def keep_alive():
258
+ return "ZeroGPU Ready"
259
+
260
 
261
  if __name__ == "__main__":
262
  import uvicorn
263
  print("🚀 Launching Fast FLUX API")
264
  keep_alive()
265
+ uvicorn.run(app, host="0.0.0.0", port=7860)