videopix commited on
Commit
7fb308a
·
verified ·
1 Parent(s): 2cc5994

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +172 -194
app.py CHANGED
@@ -1,86 +1,109 @@
1
- import os
2
  import io
 
3
  import base64
4
  import asyncio
5
- import spaces
 
 
6
  from fastapi import FastAPI, Request
7
  from fastapi.middleware.cors import CORSMiddleware
8
  from fastapi.responses import HTMLResponse, JSONResponse
9
- from concurrent.futures import ThreadPoolExecutor
10
  from PIL import Image
 
 
11
 
12
- HF_TOKEN = os.getenv("HF_TOKEN")
13
- BASE_MODEL = "black-forest-labs/FLUX.1-schnell"
14
 
15
- # concurrency
16
- executor = ThreadPoolExecutor(max_workers=3)
17
- semaphore = asyncio.Semaphore(3)
 
18
 
19
- # --------------------------------------------------------
20
- # IMPORTANT: no torch.cuda calls, no GPU detection, no
21
- # pipeline loading here. Only CPU-safe imports.
22
- # --------------------------------------------------------
23
- from diffusers import FluxPipeline
24
- import torch
25
 
26
- # --------------------------------------------------------
27
- # GPU function: runs in a separate GPU worker process.
28
- # Full model load + inference must live here.
29
- # --------------------------------------------------------
30
- @spaces.GPU
31
- def gpu_generate(prompt: str, seed: int):
32
- print("⚡ ZeroGPU worker starting model load + inference")
33
-
34
- pipe = FluxPipeline.from_pretrained(
35
- BASE_MODEL,
36
- torch_dtype=torch.float16, # safe on GPU worker
37
- use_auth_token=HF_TOKEN,
38
- low_cpu_mem_usage=True
39
- ).to("cuda")
 
 
 
 
 
40
 
 
 
41
  try:
42
- pipe.enable_attention_slicing()
43
- pipe.enable_vae_tiling()
44
- pipe.enable_xformers_memory_efficient_attention()
45
  except Exception:
46
  pass
47
 
48
- generator = torch.Generator(device="cuda").manual_seed(seed)
 
49
 
50
- img = pipe(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  prompt=prompt,
52
- width=768,
53
- height=432,
54
- num_inference_steps=6,
55
- guidance_scale=2.5,
 
56
  generator=generator,
57
- ).images[0]
 
 
58
 
59
- img = img.resize((960, 540), Image.BICUBIC)
60
 
61
- buf = io.BytesIO()
62
- img.save(buf, format="PNG")
63
- return base64.b64encode(buf.getvalue()).decode("utf-8")
 
 
64
 
65
 
66
- # --------------------------------------------------------
67
- # Async wrapper to allow multiple simultaneous requests
68
- # --------------------------------------------------------
69
- async def generate_image_async(prompt, seed):
70
  async with semaphore:
71
  loop = asyncio.get_running_loop()
72
  return await loop.run_in_executor(
73
  executor,
74
- gpu_generate,
75
  prompt,
76
- seed
 
 
 
 
 
77
  )
78
 
79
 
80
- # --------------------------------------------------------
81
- # FastAPI app
82
- # --------------------------------------------------------
83
- app = FastAPI(title="FLUX Fast API", version="3.1")
84
 
85
  app.add_middleware(
86
  CORSMiddleware,
@@ -91,175 +114,130 @@ app.add_middleware(
91
  )
92
 
93
 
 
 
 
94
  @app.get("/", response_class=HTMLResponse)
95
  def home():
96
  return """
97
  <!doctype html>
98
- <html lang="en">
99
  <head>
100
  <meta charset="utf-8" />
101
- <meta name="viewport" content="width=device-width,initial-scale=1" />
102
- <title>FLUX Fast Generator</title>
103
  <style>
104
- :root{font-family:Inter, Roboto, Arial, sans-serif; color:#111}
105
- body{max-width:900px;margin:32px auto;padding:24px;line-height:1.45}
106
- h1{font-size:1.6rem;margin:0 0 12px}
107
- p.lead{color:#444;margin:0 0 18px}
108
- .card{border:1px solid #e6e6e6;border-radius:12px;padding:18px;box-shadow:0 4px 14px rgba(20,20,20,0.03)}
109
- label{display:block;margin:12px 0 6px;font-weight:600}
110
- input[type="text"], input[type="number"], textarea{
111
- width:100%;box-sizing:border-box;padding:10px;border-radius:8px;border:1px solid #d5d5d5;font-size:14px
112
- }
113
- textarea{min-height:100px;resize:vertical}
114
- .row{display:flex;gap:12px;align-items:center;margin-top:12px}
115
- button{padding:10px 16px;border-radius:8px;border:0;background:#111;color:#fff;cursor:pointer}
116
- button.secondary{background:#f3f3f3;color:#111;border:1px solid #ddd}
117
- button:disabled{opacity:0.6;cursor:not-allowed}
118
- .meta{font-size:13px;color:#666;margin-top:8px}
119
- .result{margin-top:18px;text-align:center}
120
- .result img{max-width:100%;border-radius:12px;box-shadow:0 6px 30px rgba(0,0,0,0.06)}
121
- .footer{margin-top:18px;font-size:13px;color:#666;text-align:center}
122
- .progress{display:inline-flex;align-items:center;gap:10px}
123
- .spinner{
124
- width:18px;height:18px;border-radius:50%;border:3px solid rgba(0,0,0,0.08);border-top-color:#111;animation:spin 1s linear infinite
125
- }
126
- @keyframes spin{to{transform:rotate(360deg)}}
127
- .download{display:inline-block;margin-top:8px;padding:8px 12px;border-radius:8px;background:#fff;border:1px solid #ddd;color:#111;text-decoration:none}
128
  </style>
129
  </head>
130
  <body>
131
- <h1>FLUX Fast Generator</h1>
132
- <p class="lead">Enter a prompt and press Generate. The backend runs model inference and returns the generated image.</p>
133
- <div class="card">
134
- <form id="genForm">
135
- <label for="prompt">Prompt</label>
136
- <textarea id="prompt" placeholder="A scene of a futuristic city at golden hour, cinematic lighting, ultra-detailed..." required></textarea>
137
- <div style="display:flex;gap:12px;flex-wrap:wrap;margin-top:8px;">
138
- <div style="flex:1;min-width:160px">
139
- <label for="seed">Seed (optional)</label>
140
- <input id="seed" type="number" value="42" />
141
- </div>
142
- <div style="width:160px">
143
- <label for="steps">Steps</label>
144
- <input id="steps" type="number" value="6" min="1" max="50" />
145
- </div>
146
- <div style="width:160px">
147
- <label for="scale">Guidance</label>
148
- <input id="scale" type="number" step="0.1" value="2.5" min="1" max="20" />
149
- </div>
150
- </div>
151
- <div class="row" style="margin-top:18px">
152
- <button id="genBtn" type="submit">Generate</button>
153
- <button id="clearBtn" type="button" class="secondary">Clear</button>
154
- <div class="meta" id="status" style="margin-left:auto"></div>
155
- </div>
156
- </form>
157
- <div class="result" id="resultArea" aria-live="polite"></div>
158
- </div>
159
- <div class="footer">Tip: keep steps and resolution low for faster results in CPU or cold GPU environments.</div>
160
- <script>
161
- const form = document.getElementById('genForm');
162
- const promptInput = document.getElementById('prompt');
163
- const seedInput = document.getElementById('seed');
164
- const stepsInput = document.getElementById('steps');
165
- const scaleInput = document.getElementById('scale');
166
- const genBtn = document.getElementById('genBtn');
167
- const clearBtn = document.getElementById('clearBtn');
168
- const status = document.getElementById('status');
169
- const resultArea = document.getElementById('resultArea');
170
- clearBtn.addEventListener('click', () => {
171
- promptInput.value = '';
172
- resultArea.innerHTML = '';
173
- status.textContent = '';
174
- });
175
- form.addEventListener('submit', async (e) => {
176
- e.preventDefault();
177
- const prompt = promptInput.value.trim();
178
- if (!prompt) {
179
- status.textContent = 'Please enter a prompt';
180
- return;
181
- }
182
- const payload = {
183
- prompt: prompt,
184
- seed: parseInt(seedInput.value || 42),
185
- num_inference_steps: parseInt(stepsInput.value || 6),
186
- guidance_scale: parseFloat(scaleInput.value || 2.5)
187
- };
188
- // UI state
189
- genBtn.disabled = true;
190
- clearBtn.disabled = true;
191
- status.innerHTML = '<span class="progress"><span class="spinner"></span> Generating...</span>';
192
- resultArea.innerHTML = '';
193
- const start = Date.now();
194
- try {
195
- const res = await fetch('/api/generate', {
196
- method: 'POST',
197
- headers: {'Content-Type': 'application/json'},
198
- body: JSON.stringify(payload)
199
- });
200
- const json = await res.json();
201
- if (!res.ok || json.status !== 'success') {
202
- const msg = json && json.message ? json.message : 'Generation failed';
203
- status.textContent = 'Error: ' + msg;
204
- genBtn.disabled = false;
205
- clearBtn.disabled = false;
206
- return;
207
  }
208
- const took = ((Date.now() - start) / 1000).toFixed(1);
209
- status.textContent = `Done in ${took}s`;
210
- const imgData = 'data:image/png;base64,' + json.image_base64;
211
- const img = document.createElement('img');
212
- img.src = imgData;
213
- img.alt = prompt;
214
- resultArea.appendChild(img);
215
- const dl = document.createElement('a');
216
- dl.href = imgData;
217
- dl.download = 'flux_gen.png';
218
- dl.className = 'download';
219
- dl.textContent = 'Download PNG';
220
- resultArea.appendChild(dl);
221
- } catch (err) {
222
- console.error(err);
223
- status.textContent = 'Network or server error';
224
- } finally {
225
- genBtn.disabled = false;
226
- clearBtn.disabled = false;
227
- }
228
- });
229
- </script>
230
  </body>
231
  </html>
232
  """
233
 
234
 
 
 
 
235
  @app.post("/api/generate")
236
  async def api_generate(request: Request):
 
237
  try:
238
  data = await request.json()
239
- prompt = str(data.get("prompt", "")).strip()
240
- seed = int(data.get("seed", 42))
241
-
242
- if not prompt:
243
- return JSONResponse({"status": "error", "message": "Prompt required"}, 400)
244
-
245
  except Exception:
246
  return JSONResponse({"status": "error", "message": "Invalid JSON"}, 400)
247
 
 
 
 
 
 
 
 
 
 
 
248
  try:
249
- img64 = await generate_image_async(prompt, seed)
250
- return JSONResponse({"status": "success", "image_base64": img64, "prompt": prompt})
251
- except Exception as e:
252
- print("❌ Error:", e)
253
- return JSONResponse({"status": "error", "message": str(e)}, 500)
254
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
 
256
- @spaces.GPU
257
- def keep_alive():
258
- return "ZeroGPU Ready"
259
 
260
 
 
 
 
261
  if __name__ == "__main__":
262
  import uvicorn
263
- print("🚀 Launching Fast FLUX API")
264
- keep_alive()
265
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
1
  import io
2
+ import os
3
  import base64
4
  import asyncio
5
+ import random
6
+ from concurrent.futures import ThreadPoolExecutor
7
+
8
  from fastapi import FastAPI, Request
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from fastapi.responses import HTMLResponse, JSONResponse
11
+
12
  from PIL import Image
13
+ import torch
14
+ from diffusers import DiffusionPipeline
15
 
 
 
16
 
17
+ # -------------------------------------------------------------
18
+ # HuggingFace Token (optional)
19
+ # -------------------------------------------------------------
20
+ HF_TOKEN = os.getenv("HF_TOKEN") # <-- added
21
 
 
 
 
 
 
 
22
 
23
+ # -------------------------------------------------------------
24
+ # Model / device setup
25
+ # -------------------------------------------------------------
26
+ MODEL_REPO = "stabilityai/sdxl-turbo"
27
+
28
+ device = "cuda" if torch.cuda.is_available() else "cpu"
29
+ dtype = torch.float16 if torch.cuda.is_available() else torch.float32
30
+
31
+ print(f"Loading {MODEL_REPO} on {device} with dtype={dtype}...")
32
+
33
+ # Load with token if present
34
+ pipe = DiffusionPipeline.from_pretrained(
35
+ MODEL_REPO,
36
+ torch_dtype=dtype,
37
+ use_safetensors=True,
38
+ token=HF_TOKEN if HF_TOKEN else None, # <-- added
39
+ )
40
+
41
+ pipe.to(device)
42
 
43
+ # Optional CPU optimization
44
+ if device == "cpu":
45
  try:
46
+ pipe.enable_model_cpu_offload()
 
 
47
  except Exception:
48
  pass
49
 
50
+ print("Model ready.")
51
+
52
 
53
+ # -------------------------------------------------------------
54
+ # Image generation core
55
+ # -------------------------------------------------------------
56
+ def generate_image(
57
+ prompt: str,
58
+ negative_prompt: str,
59
+ seed: int,
60
+ width: int,
61
+ height: int,
62
+ num_inference_steps: int,
63
+ guidance_scale: float,
64
+ ):
65
+ generator = torch.Generator(device=device).manual_seed(seed)
66
+
67
+ out = pipe(
68
  prompt=prompt,
69
+ negative_prompt=negative_prompt if negative_prompt else None,
70
+ guidance_scale=guidance_scale,
71
+ num_inference_steps=num_inference_steps,
72
+ width=width,
73
+ height=height,
74
  generator=generator,
75
+ )
76
+
77
+ return out.images[0]
78
 
 
79
 
80
+ # -------------------------------------------------------------
81
+ # Async Queue
82
+ # -------------------------------------------------------------
83
+ executor = ThreadPoolExecutor(max_workers=2)
84
+ semaphore = asyncio.Semaphore(2)
85
 
86
 
87
+ async def run_generate(prompt, negative_prompt, seed, width, height, steps, guidance):
 
 
 
88
  async with semaphore:
89
  loop = asyncio.get_running_loop()
90
  return await loop.run_in_executor(
91
  executor,
92
+ generate_image,
93
  prompt,
94
+ negative_prompt,
95
+ seed,
96
+ width,
97
+ height,
98
+ steps,
99
+ guidance,
100
  )
101
 
102
 
103
+ # -------------------------------------------------------------
104
+ # FastAPI App
105
+ # -------------------------------------------------------------
106
+ app = FastAPI(title="SDXL Turbo Text2Image", version="1.0")
107
 
108
  app.add_middleware(
109
  CORSMiddleware,
 
114
  )
115
 
116
 
117
+ # -------------------------------------------------------------
118
+ # Simple Web UI
119
+ # -------------------------------------------------------------
120
  @app.get("/", response_class=HTMLResponse)
121
  def home():
122
  return """
123
  <!doctype html>
124
+ <html>
125
  <head>
126
  <meta charset="utf-8" />
127
+ <title>SDXL Turbo CPU Generator</title>
 
128
  <style>
129
+ body { font-family: Arial; max-width: 900px; margin: 30px auto; }
130
+ textarea { width: 100%; padding: 10px; border-radius: 6px; border: 1px solid #ccc; margin-bottom: 10px; }
131
+ button { padding: 12px 18px; background:black; color:white; border:none; cursor:pointer; margin-top:10px; }
132
+ img { margin-top:20px; max-width:100%; border-radius:10px; }
133
+ #status { margin-top:10px; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  </style>
135
  </head>
136
  <body>
137
+ <h1>SDXL Turbo Text to Image</h1>
138
+
139
+ <textarea id="prompt" rows="3" placeholder="Astronaut in a jungle, 8k, cold colors"></textarea>
140
+
141
+ <textarea id="neg" rows="2" placeholder="Negative prompt (optional)"></textarea>
142
+
143
+ <button id="btn" onclick="gen()">Generate</button>
144
+
145
+ <div id="status"></div>
146
+ <img id="result"/>
147
+
148
+ <script>
149
+ async function gen() {
150
+ const btn = document.getElementById("btn");
151
+ const status = document.getElementById("status");
152
+ const img = document.getElementById("result");
153
+
154
+ const prompt = document.getElementById("prompt").value;
155
+ const neg = document.getElementById("neg").value;
156
+
157
+ if (!prompt.trim()) {
158
+ status.textContent = "Please enter a prompt.";
159
+ return;
160
+ }
161
+
162
+ btn.disabled = true;
163
+ status.textContent = "Generating...";
164
+ img.src = "";
165
+
166
+ const res = await fetch("/api/generate", {
167
+ method: "POST",
168
+ headers: { "Content-Type": "application/json" },
169
+ body: JSON.stringify({ prompt, negative_prompt: neg })
170
+ });
171
+
172
+ const j = await res.json();
173
+
174
+ if (j.status !== "success") {
175
+ status.textContent = "Error: " + j.message;
176
+ btn.disabled = false;
177
+ return;
178
+ }
179
+
180
+ img.src = "data:image/png;base64," + j.image_base64;
181
+ status.textContent = "Done. Seed: " + j.seed;
182
+ btn.disabled = false;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  }
184
+ </script>
185
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  </body>
187
  </html>
188
  """
189
 
190
 
191
+ # -------------------------------------------------------------
192
+ # API Endpoint
193
+ # -------------------------------------------------------------
194
  @app.post("/api/generate")
195
  async def api_generate(request: Request):
196
+
197
  try:
198
  data = await request.json()
199
+ prompt = data.get("prompt", "").strip()
200
+ negative_prompt = data.get("negative_prompt", "").strip()
 
 
 
 
201
  except Exception:
202
  return JSONResponse({"status": "error", "message": "Invalid JSON"}, 400)
203
 
204
+ if not prompt:
205
+ return JSONResponse({"status": "error", "message": "Prompt required"}, 400)
206
+
207
+ width = 768
208
+ height = 432
209
+ steps = 2
210
+ guidance = 0.0 # SDXL Turbo is trained for cfg=0
211
+
212
+ seed = random.randint(0, 2**31 - 1)
213
+
214
  try:
215
+ img = await run_generate(
216
+ prompt, negative_prompt, seed, width, height, steps, guidance
217
+ )
 
 
218
 
219
+ buf = io.BytesIO()
220
+ img.save(buf, format="PNG")
221
+ encoded = base64.b64encode(buf.getvalue()).decode()
222
+
223
+ return JSONResponse(
224
+ {
225
+ "status": "success",
226
+ "image_base64": encoded,
227
+ "seed": seed,
228
+ "width": width,
229
+ "height": height,
230
+ }
231
+ )
232
 
233
+ except Exception as e:
234
+ return JSONResponse({"status": "error", "message": str(e)}, 500)
 
235
 
236
 
237
+ # -------------------------------------------------------------
238
+ # Local run
239
+ # -------------------------------------------------------------
240
  if __name__ == "__main__":
241
  import uvicorn
242
+
243
+ uvicorn.run(app, host="0.0.0.0", port=7860)