Spaces:
Running
Running
File size: 4,541 Bytes
4cb20ce 322b854 1703ae8 7a45985 a30b23b 322b854 a45e639 9811fff 4cb20ce a30b23b ff1da03 4cb20ce 322b854 a30b23b 4cb20ce a30b23b 322b854 4cb20ce d3d1d08 4cb20ce a45e639 4cb20ce 322b854 5780526 4cb20ce a30b23b 4cb20ce 586b87e a30b23b 4cb20ce ab01cea 4cb20ce 586b87e 4cb20ce a45e639 4cb20ce c2b6ac4 a30b23b 4cb20ce 1703ae8 4cb20ce 1703ae8 a30b23b 1703ae8 c2b6ac4 0b3edc8 3579187 4cb20ce c0d6943 0b3edc8 322b854 4cb20ce 322b854 4cb20ce 322b854 4cb20ce 322b854 a902076 4cb20ce 7a45985 a45e639 4cb20ce 9541138 afc500c 4cb20ce 3579187 83180c7 a45e639 4cb20ce afc500c a45e639 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import asyncio
import base64
import io
import os
import threading
from concurrent.futures import ThreadPoolExecutor

import spaces
import torch
from diffusers import FluxPipeline
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse
from PIL import Image

# Hugging Face access token taken from the environment (None when unset).
HF_TOKEN = os.environ.get("HF_TOKEN")
# Hub repository id of the model that load_pipeline() loads.
BASE_MODEL = "black-forest-labs/FLUX.1-schnell"
# Process-wide memo of loaded pipelines, keyed by a short name ("flux").
_cached = {}

# Generation jobs run on a 3-thread pool, and the semaphore admits at most the
# same number of in-flight jobs, keeping CPU load moderate.
semaphore = asyncio.Semaphore(value=3)
executor = ThreadPoolExecutor(max_workers=3)
_load_lock = threading.Lock()


def load_pipeline():
    """Return the shared FLUX.1-schnell pipeline, loading it on first call.

    The loaded pipeline is memoized in ``_cached["flux"]``. This function is
    reached from the worker threads of ``executor`` (through
    ``generate_image_sync``), so the first load is guarded by ``_load_lock``.
    Without the lock, several simultaneous first requests would each load
    their own copy of the model.

    Returns:
        The ``FluxPipeline`` moved to CPU in float16.
    """
    # Fast path: once cached, no locking is needed (entries are never removed).
    if "flux" in _cached:
        return _cached["flux"]
    with _load_lock:
        # Re-check under the lock: another thread may have finished loading
        # while this one was waiting.
        if "flux" in _cached:
            return _cached["flux"]
        print("🔹 Loading FLUX.1-schnell (fast mode)")
        pipe = FluxPipeline.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.float16,
            # `token` is the current keyword; `use_auth_token` is deprecated.
            token=HF_TOKEN,
        ).to("cpu", dtype=torch.float16)
        # NOTE(review): float16 on CPU may be slow or lossy for FLUX; confirm
        # this dtype is intended (bfloat16 is the commonly used alternative).
        pipe.enable_attention_slicing()
        pipe.enable_vae_tiling()
        _cached["flux"] = pipe
        return pipe
_infer_lock = threading.Lock()


def generate_image_sync(prompt: str, seed: int = 42):
    """Generate one image for ``prompt`` and return it as a 960×540 PIL image.

    The image is rendered at 768×432 with 4 inference steps for speed, then
    upscaled with bicubic resampling.

    The cached pipeline is a single object shared by every worker thread of
    ``executor``. A diffusers pipeline is not safe to call concurrently on one
    instance, because its scheduler and tokenizers carry state across a call.
    The pipeline call is therefore serialized with ``_infer_lock``. Loading
    and the final resize stay outside the lock.

    Args:
        prompt: Text prompt passed to the pipeline.
        seed: Seed for the CPU ``torch.Generator`` used for sampling.

    Returns:
        A ``PIL.Image.Image`` of size 960×540.
    """
    pipe = load_pipeline()
    gen = torch.Generator(device="cpu").manual_seed(int(seed))
    # smaller size and steps for speed
    w, h = 768, 432
    with _infer_lock:
        image = pipe(
            prompt=prompt,
            width=w,
            height=h,
            num_inference_steps=4,
            # NOTE(review): FLUX.1-schnell is guidance-distilled; confirm this
            # value has any effect.
            guidance_scale=3,
            generator=gen,
        ).images[0]
    # slight upscale back to 960×540 to keep output clear
    return image.resize((960, 540), Image.BICUBIC)
async def generate_image_async(prompt, seed):
    """Await ``generate_image_sync(prompt, seed)`` without blocking the event loop.

    The blocking call is handed to the module-level ``executor`` thread pool.
    Admission is bounded by the module-level ``semaphore``, which is held for
    the whole duration of the job.
    """
    async with semaphore:
        running_loop = asyncio.get_running_loop()
        job = running_loop.run_in_executor(executor, generate_image_sync, prompt, seed)
        return await job
app = FastAPI(version="3.1", title="FLUX Fast API")

# Fully open CORS: every origin, method and request header is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True is a very
# permissive pairing; confirm credentialed cross-origin access is intended.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
@app.get("/", response_class=HTMLResponse)
def home():
    """Serve the single-page HTML UI that posts prompts to ``/api/generate``.

    The inline script does the following:
    - disables the Generate button while a request is in flight;
    - reports network failures and non-JSON responses (e.g. an HTML error
      page) instead of leaving the "Generating..." message up forever;
    - inserts server-supplied error text with ``textContent``, so it is never
      parsed as HTML.
    """
    return """
<html><head><title>FLUX Fast</title>
<style>body{font-family:Arial;text-align:center;padding:2rem}
input,button{margin:.5rem;padding:.6rem;width:300px;border-radius:6px;border:1px solid #ccc}
button{background:#444;color:#fff}button:hover{background:#333}
button:disabled{opacity:.6;cursor:wait}
img{margin-top:1rem;max-width:90%;border-radius:12px}</style></head>
<body><h2>🎨 FLUX Fast Generator</h2>
<form id='f'><input id='prompt' placeholder='Describe image...' required><br>
<input id='seed' type='number' value='42'><br>
<button>Generate</button></form><div id='out'></div>
<script>
const form = document.getElementById("f");
const promptInput = document.getElementById("prompt");
const seedInput = document.getElementById("seed");
const resultDiv = document.getElementById("out");
const submitButton = form.querySelector("button");

// Render an error line; textContent keeps server-supplied text from being parsed as HTML.
function showError(message) {
  const p = document.createElement("p");
  p.style.color = "red";
  p.textContent = "❌ " + message;
  resultDiv.innerHTML = "";
  resultDiv.appendChild(p);
}

form.addEventListener("submit", async (e) => {
  e.preventDefault();
  const prompt = promptInput.value.trim();
  if (!prompt) {
    showError("Please enter a prompt");
    return;
  }
  // parseInt yields NaN for an empty field; fall back to the default seed
  // (JSON.stringify would otherwise send NaN as null).
  const parsedSeed = Number.parseInt(seedInput.value, 10);
  const payload = {
    prompt: prompt,
    seed: Number.isNaN(parsedSeed) ? 42 : parsedSeed
  };
  resultDiv.innerHTML = "<p>⏳ Generating...</p>";
  submitButton.disabled = true;
  try {
    const res = await fetch("/api/generate", {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify(payload)
    });
    let json;
    try {
      json = await res.json();
    } catch (parseErr) {
      throw new Error("Server returned HTTP " + res.status);
    }
    if (json.status === "success") {
      resultDiv.innerHTML = `<img src="data:image/png;base64,${json.image_base64}"/><p>✅ Done!</p>`;
    } else {
      showError(json.message || ("Request failed (HTTP " + res.status + ")"));
    }
  } catch (err) {
    showError(err.message || "Request failed");
  } finally {
    submitButton.disabled = false;
  }
});
</script>
</body></html>
"""
@app.post("/api/generate")
async def api_generate(request: Request):
    """Generate an image from a JSON request and return it as base64 PNG.

    Expects a JSON object body ``{"prompt": str, "seed": int}``. ``seed`` is
    optional and defaults to 42 when missing or null.

    Returns:
        A 200 ``{"status": "success", "prompt", "image_base64"}`` on success.
        A 400 ``{"status": "error", "message"}`` for a malformed request.
        A 500 ``{"status": "error", "message"}`` if generation or encoding fails.
    """
    try:
        data = await request.json()
    except ValueError:  # JSONDecodeError and UnicodeDecodeError subclass ValueError
        return JSONResponse({"status": "error", "message": "Invalid JSON"}, 400)
    if not isinstance(data, dict):
        return JSONResponse({"status": "error", "message": "Expected a JSON object"}, 400)

    raw_prompt = data.get("prompt")
    # A null prompt must not turn into the literal text "None".
    prompt = "" if raw_prompt is None else str(raw_prompt).strip()
    if not prompt:
        return JSONResponse({"status": "error", "message": "Prompt required"}, 400)

    raw_seed = data.get("seed")
    try:
        seed = 42 if raw_seed is None else int(raw_seed)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: json.loads accepts Infinity, and int(inf) overflows.
        return JSONResponse({"status": "error", "message": "Seed must be an integer"}, 400)
    # torch.Generator.manual_seed raises for seeds outside this range; report
    # it as a client error instead of a generation failure.
    if not -0x8000_0000_0000_0000 <= seed <= 0xFFFF_FFFF_FFFF_FFFF:
        return JSONResponse({"status": "error", "message": "Seed out of range"}, 400)

    try:
        image = await generate_image_async(prompt, seed)
        buf = io.BytesIO()
        image.save(buf, format="PNG")
        img64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    except Exception as e:  # request boundary: turn any failure into a JSON 500
        print(f"❌ Error: {e}")
        return JSONResponse({"status": "error", "message": str(e)}, 500)
    return JSONResponse({"status": "success", "prompt": prompt, "image_base64": img64})
@spaces.GPU
def keep_alive():
    """Return a fixed readiness string.

    The body does no GPU work. The ``spaces.GPU`` decorator is presumably
    there so the Space exposes at least one ZeroGPU-registered function;
    confirm.
    """
    ready_message = "ZeroGPU Ready"
    return ready_message
if __name__ == "__main__":
    import uvicorn

    bind_host, bind_port = "0.0.0.0", 7860
    print("🚀 Launching Fast FLUX API")
    # Invoke the spaces.GPU-decorated function once before serving.
    keep_alive()
    uvicorn.run(app, host=bind_host, port=bind_port)
|