# Hugging Face Spaces page header ("Spaces: Running") — scrape artifact, not code.
import os
import io
import base64
import asyncio
from concurrent.futures import ThreadPoolExecutor

# NOTE: `spaces` is imported before torch (as in the ZeroGPU docs) so the
# runtime can patch CUDA initialisation — keep it ahead of the imports below.
import spaces

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse
from PIL import Image

# --------------------------------------------------------
# IMPORTANT: no torch.cuda calls, no GPU detection, no
# pipeline loading here. Only CPU-safe imports.
# --------------------------------------------------------
import torch
from diffusers import FluxPipeline

# Hugging Face access token (supplied via the Space's secrets).
HF_TOKEN = os.getenv("HF_TOKEN")
BASE_MODEL = "black-forest-labs/FLUX.1-schnell"

# Concurrency: allow at most 3 generations in flight at any one time.
executor = ThreadPoolExecutor(max_workers=3)
semaphore = asyncio.Semaphore(3)
# --------------------------------------------------------
# GPU function: runs in a separate GPU worker process.
# Full model load + inference must live here.
# --------------------------------------------------------
@spaces.GPU  # required on ZeroGPU: without it this would run (and fail) on the CPU host
def gpu_generate(prompt: str, seed: int,
                 num_inference_steps: int = 6,
                 guidance_scale: float = 2.5) -> str:
    """Load FLUX.1-schnell in the GPU worker and render one image.

    Args:
        prompt: Text description of the desired image.
        seed: RNG seed, so the same request reproduces the same image.
        num_inference_steps: Diffusion steps (default 6, matching the UI).
        guidance_scale: Classifier-free guidance strength (default 2.5).

    Returns:
        The final 960x540 image, PNG-encoded as a base64 string.
    """
    print("⚡ ZeroGPU worker starting model load + inference")
    pipe = FluxPipeline.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.float16,  # safe on GPU worker
        token=HF_TOKEN,  # `use_auth_token` is deprecated in diffusers; `token` replaces it
        low_cpu_mem_usage=True,
    ).to("cuda")

    # Best-effort memory optimisations; xformers may not be installed,
    # so any failure here is deliberately ignored.
    try:
        pipe.enable_attention_slicing()
        pipe.enable_vae_tiling()
        pipe.enable_xformers_memory_efficient_attention()
    except Exception:
        pass

    generator = torch.Generator(device="cuda").manual_seed(seed)
    img = pipe(
        prompt=prompt,
        width=768,
        height=432,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
    ).images[0]

    # Upscale the 768x432 render to the final 16:9 output and PNG-encode it.
    img = img.resize((960, 540), Image.BICUBIC)
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("utf-8")
# --------------------------------------------------------
# Async wrapper to allow multiple simultaneous requests
# --------------------------------------------------------
async def generate_image_async(prompt, seed):
    """Run the blocking GPU generation off the event loop.

    The semaphore caps concurrent generations; the thread pool keeps the
    blocking call from stalling other requests.
    """
    async with semaphore:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(executor, gpu_generate, prompt, seed)
# --------------------------------------------------------
# FastAPI app
# --------------------------------------------------------
app = FastAPI(title="FLUX Fast API", version="3.1")

# Open CORS so browser clients on any origin can call the JSON API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/", response_class=HTMLResponse)  # register the route: without this the page is never served
def home():
    """Serve the single-page UI that posts prompts to /api/generate."""
    return """
<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width,initial-scale=1" />
<title>FLUX Fast Generator</title>
<style>
:root{font-family:Inter, Roboto, Arial, sans-serif; color:#111}
body{max-width:900px;margin:32px auto;padding:24px;line-height:1.45}
h1{font-size:1.6rem;margin:0 0 12px}
p.lead{color:#444;margin:0 0 18px}
.card{border:1px solid #e6e6e6;border-radius:12px;padding:18px;box-shadow:0 4px 14px rgba(20,20,20,0.03)}
label{display:block;margin:12px 0 6px;font-weight:600}
input[type="text"], input[type="number"], textarea{
width:100%;box-sizing:border-box;padding:10px;border-radius:8px;border:1px solid #d5d5d5;font-size:14px
}
textarea{min-height:100px;resize:vertical}
.row{display:flex;gap:12px;align-items:center;margin-top:12px}
button{padding:10px 16px;border-radius:8px;border:0;background:#111;color:#fff;cursor:pointer}
button.secondary{background:#f3f3f3;color:#111;border:1px solid #ddd}
button:disabled{opacity:0.6;cursor:not-allowed}
.meta{font-size:13px;color:#666;margin-top:8px}
.result{margin-top:18px;text-align:center}
.result img{max-width:100%;border-radius:12px;box-shadow:0 6px 30px rgba(0,0,0,0.06)}
.footer{margin-top:18px;font-size:13px;color:#666;text-align:center}
.progress{display:inline-flex;align-items:center;gap:10px}
.spinner{
width:18px;height:18px;border-radius:50%;border:3px solid rgba(0,0,0,0.08);border-top-color:#111;animation:spin 1s linear infinite
}
@keyframes spin{to{transform:rotate(360deg)}}
.download{display:inline-block;margin-top:8px;padding:8px 12px;border-radius:8px;background:#fff;border:1px solid #ddd;color:#111;text-decoration:none}
</style>
</head>
<body>
<h1>FLUX Fast Generator</h1>
<p class="lead">Enter a prompt and press Generate. The backend runs model inference and returns the generated image.</p>
<div class="card">
<form id="genForm">
<label for="prompt">Prompt</label>
<textarea id="prompt" placeholder="A scene of a futuristic city at golden hour, cinematic lighting, ultra-detailed..." required></textarea>
<div style="display:flex;gap:12px;flex-wrap:wrap;margin-top:8px;">
<div style="flex:1;min-width:160px">
<label for="seed">Seed (optional)</label>
<input id="seed" type="number" value="42" />
</div>
<div style="width:160px">
<label for="steps">Steps</label>
<input id="steps" type="number" value="6" min="1" max="50" />
</div>
<div style="width:160px">
<label for="scale">Guidance</label>
<input id="scale" type="number" step="0.1" value="2.5" min="1" max="20" />
</div>
</div>
<div class="row" style="margin-top:18px">
<button id="genBtn" type="submit">Generate</button>
<button id="clearBtn" type="button" class="secondary">Clear</button>
<div class="meta" id="status" style="margin-left:auto"></div>
</div>
</form>
<div class="result" id="resultArea" aria-live="polite"></div>
</div>
<div class="footer">Tip: keep steps and resolution low for faster results in CPU or cold GPU environments.</div>
<script>
const form = document.getElementById('genForm');
const promptInput = document.getElementById('prompt');
const seedInput = document.getElementById('seed');
const stepsInput = document.getElementById('steps');
const scaleInput = document.getElementById('scale');
const genBtn = document.getElementById('genBtn');
const clearBtn = document.getElementById('clearBtn');
const status = document.getElementById('status');
const resultArea = document.getElementById('resultArea');
clearBtn.addEventListener('click', () => {
promptInput.value = '';
resultArea.innerHTML = '';
status.textContent = '';
});
form.addEventListener('submit', async (e) => {
e.preventDefault();
const prompt = promptInput.value.trim();
if (!prompt) {
status.textContent = 'Please enter a prompt';
return;
}
const payload = {
prompt: prompt,
seed: parseInt(seedInput.value || 42),
num_inference_steps: parseInt(stepsInput.value || 6),
guidance_scale: parseFloat(scaleInput.value || 2.5)
};
// UI state
genBtn.disabled = true;
clearBtn.disabled = true;
status.innerHTML = '<span class="progress"><span class="spinner"></span> Generating...</span>';
resultArea.innerHTML = '';
const start = Date.now();
try {
const res = await fetch('/api/generate', {
method: 'POST',
headers: {'Content-Type': 'application/json'},
body: JSON.stringify(payload)
});
const json = await res.json();
if (!res.ok || json.status !== 'success') {
const msg = json && json.message ? json.message : 'Generation failed';
status.textContent = 'Error: ' + msg;
genBtn.disabled = false;
clearBtn.disabled = false;
return;
}
const took = ((Date.now() - start) / 1000).toFixed(1);
status.textContent = `Done in ${took}s`;
const imgData = 'data:image/png;base64,' + json.image_base64;
const img = document.createElement('img');
img.src = imgData;
img.alt = prompt;
resultArea.appendChild(img);
const dl = document.createElement('a');
dl.href = imgData;
dl.download = 'flux_gen.png';
dl.className = 'download';
dl.textContent = 'Download PNG';
resultArea.appendChild(dl);
} catch (err) {
console.error(err);
status.textContent = 'Network or server error';
} finally {
genBtn.disabled = false;
clearBtn.disabled = false;
}
});
</script>
</body>
</html>
"""
@app.post("/api/generate")  # register the endpoint the frontend fetch() targets
async def api_generate(request: Request):
    """Generate an image from a JSON body and return it as base64 PNG.

    Expects ``{"prompt": str, "seed": int}``; responds with
    ``{"status": "success", "image_base64": ..., "prompt": ...}`` or an
    error payload with an appropriate HTTP status code.
    """
    # Validate the request body; reject malformed JSON or an empty prompt.
    try:
        data = await request.json()
        prompt = str(data.get("prompt", "")).strip()
        seed = int(data.get("seed", 42))
        if not prompt:
            return JSONResponse({"status": "error", "message": "Prompt required"}, status_code=400)
    except Exception:
        return JSONResponse({"status": "error", "message": "Invalid JSON"}, status_code=400)

    # NOTE(review): the UI also sends num_inference_steps / guidance_scale,
    # but the generation path currently does not forward them — confirm
    # whether they should be plumbed through.
    try:
        img64 = await generate_image_async(prompt, seed)
        return JSONResponse({"status": "success", "image_base64": img64, "prompt": prompt})
    except Exception as e:
        print("❌ Error:", e)
        return JSONResponse({"status": "error", "message": str(e)}, status_code=500)
def keep_alive():
    """Readiness check: report that the ZeroGPU service is up."""
    return "ZeroGPU Ready"
if __name__ == "__main__":
    # Local/Space entry point: bind the API on the standard Spaces port.
    import uvicorn

    print("🚀 Launching Fast FLUX API")
    keep_alive()
    uvicorn.run(app, host="0.0.0.0", port=7860)