# app.py — FLUX.1-schnell text-to-image FastAPI service (Hugging Face Space)
# Source: Gibhili Space, uploaded by videopix (revision 2cc5994).
import os
import io
import base64
import asyncio
import spaces
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse
from concurrent.futures import ThreadPoolExecutor
from PIL import Image
# Hugging Face access token for model downloads; may be None if not configured.
HF_TOKEN = os.getenv("HF_TOKEN")
# Text-to-image model loaded inside the GPU worker on each request.
BASE_MODEL = "black-forest-labs/FLUX.1-schnell"
# Concurrency controls: at most 3 generations in flight at once.
# The semaphore gates async callers; the executor runs the blocking GPU call.
executor = ThreadPoolExecutor(max_workers=3)
semaphore = asyncio.Semaphore(3)
# --------------------------------------------------------
# IMPORTANT: no torch.cuda calls, no GPU detection, no
# pipeline loading here. Only CPU-safe imports.
# --------------------------------------------------------
from diffusers import FluxPipeline
import torch
# --------------------------------------------------------
# GPU function: runs in a separate GPU worker process.
# Full model load + inference must live here.
# --------------------------------------------------------
@spaces.GPU
def gpu_generate(prompt: str, seed: int) -> str:
    """Load FLUX.1-schnell in the ZeroGPU worker and render one image.

    Runs inside the GPU worker process (dispatched by the ``@spaces.GPU``
    decorator), so the full pipeline load happens on every call.

    Args:
        prompt: Text prompt describing the image.
        seed: Seed for the CUDA RNG, making output reproducible.

    Returns:
        The generated 960x540 PNG encoded as a base64 string.
    """
    print("⚡ ZeroGPU worker starting model load + inference")
    pipe = FluxPipeline.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.float16,  # fp16 is safe on the GPU worker
        # `use_auth_token` is deprecated in diffusers/huggingface_hub;
        # `token` is the current keyword for the same credential.
        token=HF_TOKEN,
        low_cpu_mem_usage=True,
    ).to("cuda")
    # Best-effort memory optimizations; xformers may not be installed, so
    # failures here are deliberately ignored rather than aborting the job.
    try:
        pipe.enable_attention_slicing()
        pipe.enable_vae_tiling()
        pipe.enable_xformers_memory_efficient_attention()
    except Exception:
        pass
    generator = torch.Generator(device="cuda").manual_seed(seed)
    # NOTE(review): schnell is a distilled, guidance-free variant — the
    # guidance_scale value likely has no effect; confirm before tuning it.
    img = pipe(
        prompt=prompt,
        width=768,
        height=432,
        num_inference_steps=6,
        guidance_scale=2.5,
        generator=generator,
    ).images[0]
    # Upscale from the render size to the delivery size; BICUBIC keeps it sharp.
    img = img.resize((960, 540), Image.BICUBIC)
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("utf-8")
# --------------------------------------------------------
# Async wrapper to allow multiple simultaneous requests
# --------------------------------------------------------
async def generate_image_async(prompt, seed):
    """Run the blocking GPU call in the thread pool, gated by the semaphore.

    At most three requests proceed concurrently; the rest wait here without
    blocking the event loop.
    """
    async with semaphore:
        event_loop = asyncio.get_running_loop()
        result = await event_loop.run_in_executor(executor, gpu_generate, prompt, seed)
        return result
# --------------------------------------------------------
# FastAPI app
# --------------------------------------------------------
# FastAPI application with fully-open CORS so the browser UI (or any external
# origin) can call /api/generate directly.
app = FastAPI(title="FLUX Fast API", version="3.1")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/", response_class=HTMLResponse)
def home():
    """Serve the single-page browser UI as inline HTML.

    NOTE(review): the form collects `steps` and `guidance` and posts them as
    `num_inference_steps` / `guidance_scale`, but /api/generate only reads
    `prompt` and `seed` — those two fields are currently ignored server-side.
    """
    return """
<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width,initial-scale=1" />
<title>FLUX Fast Generator</title>
<style>
:root{font-family:Inter, Roboto, Arial, sans-serif; color:#111}
body{max-width:900px;margin:32px auto;padding:24px;line-height:1.45}
h1{font-size:1.6rem;margin:0 0 12px}
p.lead{color:#444;margin:0 0 18px}
.card{border:1px solid #e6e6e6;border-radius:12px;padding:18px;box-shadow:0 4px 14px rgba(20,20,20,0.03)}
label{display:block;margin:12px 0 6px;font-weight:600}
input[type="text"], input[type="number"], textarea{
width:100%;box-sizing:border-box;padding:10px;border-radius:8px;border:1px solid #d5d5d5;font-size:14px
}
textarea{min-height:100px;resize:vertical}
.row{display:flex;gap:12px;align-items:center;margin-top:12px}
button{padding:10px 16px;border-radius:8px;border:0;background:#111;color:#fff;cursor:pointer}
button.secondary{background:#f3f3f3;color:#111;border:1px solid #ddd}
button:disabled{opacity:0.6;cursor:not-allowed}
.meta{font-size:13px;color:#666;margin-top:8px}
.result{margin-top:18px;text-align:center}
.result img{max-width:100%;border-radius:12px;box-shadow:0 6px 30px rgba(0,0,0,0.06)}
.footer{margin-top:18px;font-size:13px;color:#666;text-align:center}
.progress{display:inline-flex;align-items:center;gap:10px}
.spinner{
width:18px;height:18px;border-radius:50%;border:3px solid rgba(0,0,0,0.08);border-top-color:#111;animation:spin 1s linear infinite
}
@keyframes spin{to{transform:rotate(360deg)}}
.download{display:inline-block;margin-top:8px;padding:8px 12px;border-radius:8px;background:#fff;border:1px solid #ddd;color:#111;text-decoration:none}
</style>
</head>
<body>
<h1>FLUX Fast Generator</h1>
<p class="lead">Enter a prompt and press Generate. The backend runs model inference and returns the generated image.</p>
<div class="card">
<form id="genForm">
<label for="prompt">Prompt</label>
<textarea id="prompt" placeholder="A scene of a futuristic city at golden hour, cinematic lighting, ultra-detailed..." required></textarea>
<div style="display:flex;gap:12px;flex-wrap:wrap;margin-top:8px;">
<div style="flex:1;min-width:160px">
<label for="seed">Seed (optional)</label>
<input id="seed" type="number" value="42" />
</div>
<div style="width:160px">
<label for="steps">Steps</label>
<input id="steps" type="number" value="6" min="1" max="50" />
</div>
<div style="width:160px">
<label for="scale">Guidance</label>
<input id="scale" type="number" step="0.1" value="2.5" min="1" max="20" />
</div>
</div>
<div class="row" style="margin-top:18px">
<button id="genBtn" type="submit">Generate</button>
<button id="clearBtn" type="button" class="secondary">Clear</button>
<div class="meta" id="status" style="margin-left:auto"></div>
</div>
</form>
<div class="result" id="resultArea" aria-live="polite"></div>
</div>
<div class="footer">Tip: keep steps and resolution low for faster results in CPU or cold GPU environments.</div>
<script>
const form = document.getElementById('genForm');
const promptInput = document.getElementById('prompt');
const seedInput = document.getElementById('seed');
const stepsInput = document.getElementById('steps');
const scaleInput = document.getElementById('scale');
const genBtn = document.getElementById('genBtn');
const clearBtn = document.getElementById('clearBtn');
const status = document.getElementById('status');
const resultArea = document.getElementById('resultArea');
clearBtn.addEventListener('click', () => {
promptInput.value = '';
resultArea.innerHTML = '';
status.textContent = '';
});
form.addEventListener('submit', async (e) => {
e.preventDefault();
const prompt = promptInput.value.trim();
if (!prompt) {
status.textContent = 'Please enter a prompt';
return;
}
const payload = {
prompt: prompt,
seed: parseInt(seedInput.value || 42),
num_inference_steps: parseInt(stepsInput.value || 6),
guidance_scale: parseFloat(scaleInput.value || 2.5)
};
// UI state
genBtn.disabled = true;
clearBtn.disabled = true;
status.innerHTML = '<span class="progress"><span class="spinner"></span> Generating...</span>';
resultArea.innerHTML = '';
const start = Date.now();
try {
const res = await fetch('/api/generate', {
method: 'POST',
headers: {'Content-Type': 'application/json'},
body: JSON.stringify(payload)
});
const json = await res.json();
if (!res.ok || json.status !== 'success') {
const msg = json && json.message ? json.message : 'Generation failed';
status.textContent = 'Error: ' + msg;
genBtn.disabled = false;
clearBtn.disabled = false;
return;
}
const took = ((Date.now() - start) / 1000).toFixed(1);
status.textContent = `Done in ${took}s`;
const imgData = 'data:image/png;base64,' + json.image_base64;
const img = document.createElement('img');
img.src = imgData;
img.alt = prompt;
resultArea.appendChild(img);
const dl = document.createElement('a');
dl.href = imgData;
dl.download = 'flux_gen.png';
dl.className = 'download';
dl.textContent = 'Download PNG';
resultArea.appendChild(dl);
} catch (err) {
console.error(err);
status.textContent = 'Network or server error';
} finally {
genBtn.disabled = false;
clearBtn.disabled = false;
}
});
</script>
</body>
</html>
"""
@app.post("/api/generate")
async def api_generate(request: Request):
    """Generate an image from a JSON body ``{"prompt": ..., "seed": ...}``.

    Responses:
        200: {"status": "success", "image_base64": ..., "prompt": ...}
        400: invalid JSON body, empty prompt, or non-integer seed
        500: generation failure (message carries the exception text)
    """
    # Keep each try minimal so error messages match the actual failure:
    # the original lumped seed parsing in with JSON decoding and reported
    # a bad seed as "Invalid JSON".
    try:
        data = await request.json()
    except Exception:
        return JSONResponse({"status": "error", "message": "Invalid JSON"}, status_code=400)
    # A valid JSON body that isn't an object (e.g. a list) has no .get().
    if not isinstance(data, dict):
        return JSONResponse({"status": "error", "message": "Invalid JSON"}, status_code=400)
    prompt = str(data.get("prompt", "")).strip()
    if not prompt:
        return JSONResponse({"status": "error", "message": "Prompt required"}, status_code=400)
    try:
        seed = int(data.get("seed", 42))
    except (TypeError, ValueError):
        return JSONResponse({"status": "error", "message": "Seed must be an integer"}, status_code=400)
    # NOTE(review): the UI also posts num_inference_steps / guidance_scale,
    # but the pipeline call hard-codes those values; they are ignored here.
    try:
        img64 = await generate_image_async(prompt, seed)
        return JSONResponse({"status": "success", "image_base64": img64, "prompt": prompt})
    except Exception as e:
        print("❌ Error:", e)
        return JSONResponse({"status": "error", "message": str(e)}, status_code=500)
@spaces.GPU
def keep_alive():
    """Minimal GPU-decorated call; returns a readiness string."""
    ready_message = "ZeroGPU Ready"
    return ready_message
if __name__ == "__main__":
    import uvicorn
    print("🚀 Launching Fast FLUX API")
    # Invoke the GPU-decorated no-op once before serving; presumably this
    # warms/registers the ZeroGPU worker — confirm against the spaces docs.
    keep_alive()
    # Bind on all interfaces at the standard Hugging Face Spaces port.
    uvicorn.run(app, host="0.0.0.0", port=7860)