File size: 4,541 Bytes
4cb20ce
322b854
1703ae8
7a45985
a30b23b
322b854
a45e639
9811fff
4cb20ce
a30b23b
ff1da03
4cb20ce
 
 
 
322b854
a30b23b
4cb20ce
 
 
a30b23b
322b854
4cb20ce
d3d1d08
4cb20ce
a45e639
 
4cb20ce
322b854
5780526
4cb20ce
a30b23b
4cb20ce
 
 
586b87e
a30b23b
4cb20ce
 
ab01cea
 
4cb20ce
586b87e
4cb20ce
 
a45e639
4cb20ce
c2b6ac4
a30b23b
4cb20ce
1703ae8
4cb20ce
1703ae8
 
a30b23b
1703ae8
 
 
 
c2b6ac4
0b3edc8
 
3579187
4cb20ce
 
 
 
 
 
 
 
 
c0d6943
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0b3edc8
 
 
322b854
 
 
4cb20ce
 
322b854
4cb20ce
322b854
4cb20ce
322b854
a902076
4cb20ce
 
 
 
 
7a45985
a45e639
4cb20ce
9541138
afc500c
4cb20ce
3579187
83180c7
a45e639
4cb20ce
afc500c
a45e639
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import asyncio
import base64
import io
import os
import threading
from concurrent.futures import ThreadPoolExecutor

import spaces
import torch
from diffusers import FluxPipeline
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse
from PIL import Image

# Checkpoint served by this app, plus the optional Hub token read from the
# environment (None when HF_TOKEN is unset).
BASE_MODEL = "black-forest-labs/FLUX.1-schnell"
HF_TOKEN = os.environ.get("HF_TOKEN")

# Process-wide cache; the loaded pipeline is stored under the "flux" key.
_cached: dict = {}

# Moderate concurrency so the CPU does not choke: the worker pool and the
# semaphore that guards it share the same limit of three.
_MAX_PARALLEL_JOBS = 3
semaphore = asyncio.Semaphore(_MAX_PARALLEL_JOBS)
executor = ThreadPoolExecutor(max_workers=_MAX_PARALLEL_JOBS)

# Serializes first-time pipeline construction across the executor threads.
_pipeline_lock = threading.Lock()


def load_pipeline():
    """Return the shared FLUX.1-schnell pipeline, loading it on first use.

    The pipeline is built once, moved to the CPU in float16, configured with
    attention slicing and VAE tiling, and stored in ``_cached["flux"]``.
    Later calls return the cached object.

    Loading is guarded by a lock because this function is reached from
    several thread-pool workers: without it, concurrent first requests would
    each find the cache empty and each load their own copy of the model.

    Returns:
        The cached ``FluxPipeline`` instance.
    """
    pipe = _cached.get("flux")
    if pipe is not None:
        return pipe

    with _pipeline_lock:
        # Re-check under the lock: another worker may have finished loading
        # while this one was waiting.
        pipe = _cached.get("flux")
        if pipe is not None:
            return pipe

        print("🔹 Loading FLUX.1-schnell (fast mode)")
        pipe = FluxPipeline.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.float16,
            # `token` is the documented authentication keyword;
            # `use_auth_token` is the legacy spelling.
            token=HF_TOKEN,
        ).to("cpu", dtype=torch.float16)
        pipe.enable_attention_slicing()
        pipe.enable_vae_tiling()
        # Publish only after the pipeline is fully configured so the
        # lock-free fast path above never sees a half-initialised object.
        _cached["flux"] = pipe
        return pipe

def generate_image_sync(prompt: str, seed: int = 42):
    """Render *prompt* with the shared pipeline and return a 960x540 image.

    Blocking call. The image is generated at 768x432 with four inference
    steps to keep latency down, then resized to 960x540 with bicubic
    resampling.

    Args:
        prompt: Text description passed to the pipeline.
        seed: Value used to seed the CPU random generator (coerced with
            ``int``), so the same prompt and seed use the same generator
            state.

    Returns:
        The resized image produced by ``Image.resize``.
    """
    pipeline = load_pipeline()
    rng = torch.Generator(device="cpu").manual_seed(int(seed))

    # Smaller render size and few steps for speed.
    render_size = {"width": 768, "height": 432}
    result = pipeline(
        prompt=prompt,
        num_inference_steps=4,
        guidance_scale=3,
        generator=rng,
        **render_size,
    )
    first_image = result.images[0]

    # Slight upscale back to 960x540 to keep the output clear.
    return first_image.resize((960, 540), Image.BICUBIC)

async def generate_image_async(prompt, seed):
    """Run ``generate_image_sync`` on the worker pool without blocking the loop.

    The module-level semaphore is held for the duration of the job, so at
    most three generations are in flight at once; further callers wait here.

    Args:
        prompt: Text description forwarded to ``generate_image_sync``.
        seed: Seed forwarded to ``generate_image_sync``.

    Returns:
        Whatever ``generate_image_sync`` returns for these arguments.
    """
    async with semaphore:
        pending = asyncio.get_running_loop().run_in_executor(
            executor, generate_image_sync, prompt, seed
        )
        return await pending

# ASGI application object that the route decorators in this module attach to.
app = FastAPI(title="FLUX Fast API", version="3.1")

# Fully permissive CORS policy: any origin, method and header is accepted.
# NOTE(review): FastAPI's CORS docs appear to disallow combining wildcard
# origins with allow_credentials=True — confirm whether credentialed
# cross-origin requests are actually needed here.
_OPEN_CORS_POLICY = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_OPEN_CORS_POLICY)

@app.get("/", response_class=HTMLResponse)
def home() -> str:
    """Serve the single-page UI for the generator.

    The page contains a prompt/seed form whose script POSTs JSON to
    ``/api/generate`` and shows either the returned base64 PNG or an error.

    The script wraps the request in ``try``/``catch`` so a network failure
    or a non-JSON reply shows an error instead of leaving the page on
    "Generating...". Error text is written with ``textContent`` so a
    server-supplied message is never parsed as HTML.

    Returns:
        The HTML document as a string.
    """
    return """
    <html><head><title>FLUX Fast</title>
    <style>body{font-family:Arial;text-align:center;padding:2rem}
    input,button{margin:.5rem;padding:.6rem;width:300px;border-radius:6px;border:1px solid #ccc}
    button{background:#444;color:#fff}button:hover{background:#333}
    img{margin-top:1rem;max-width:90%;border-radius:12px}</style></head>
    <body><h2>🎨 FLUX Fast Generator</h2>
    <form id='f'><input id='prompt' placeholder='Describe image...' required><br>
    <input id='seed' type='number' value='42'><br>
    <button>Generate</button></form><div id='out'></div>
  <script>
const form = document.getElementById("f");
const promptInput = document.getElementById("prompt");
const seedInput = document.getElementById("seed");
const resultDiv = document.getElementById("out");

// Show an error as plain text so messages are never interpreted as HTML.
function showError(message) {
  const p = document.createElement("p");
  p.style.color = "red";
  p.textContent = "❌ " + message;
  resultDiv.innerHTML = "";
  resultDiv.appendChild(p);
}

form.addEventListener("submit", async (e) => {
  e.preventDefault();
  const prompt = promptInput.value.trim();
  if (!prompt) {
    showError("Please enter a prompt");
    return;
  }
  resultDiv.innerHTML = "<p>⏳ Generating...</p>";
  const payload = {
    prompt: prompt,
    seed: parseInt(seedInput.value || "42", 10)
  };
  try {
    const res = await fetch("/api/generate", {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify(payload)
    });
    const json = await res.json();
    if (json.status === "success") {
      resultDiv.innerHTML = `<img src="data:image/png;base64,${json.image_base64}"/><p>✅ Done!</p>`;
    } else {
      showError(json.message || "Generation failed");
    }
  } catch (err) {
    // Network failure, or the reply was not JSON (e.g. a proxy error page).
    showError("Request failed: " + err.message);
  }
});
</script>
</body></html>
    """

def _error_response(message: str, status_code: int) -> JSONResponse:
    """Build the ``{"status": "error", "message": ...}`` reply used by the API."""
    return JSONResponse({"status": "error", "message": message}, status_code)


@app.post("/api/generate")
async def api_generate(request: Request):
    """Generate an image from a JSON request and return it as base64 PNG.

    Request body: a JSON object with ``prompt`` (required, non-empty after
    stripping) and ``seed`` (optional; defaults to 42 when missing or null).

    Responses:
        200: ``{"status": "success", "prompt": ..., "image_base64": ...}``.
        400: the body is not valid JSON, is not an object, has an empty or
            missing prompt, or has a seed that is not an integer or is out
            of range.
        500: generation or encoding failed; ``message`` carries the
            exception text.
    """
    # Bounds documented for torch.Generator.manual_seed; anything outside
    # would only fail later inside the worker and come back as a 500.
    seed_min, seed_max = -0x8000_0000_0000_0000, 0xFFFF_FFFF_FFFF_FFFF

    try:
        data = await request.json()
    except Exception:
        # Request boundary: any failure to decode the body is a client error.
        return _error_response("Invalid JSON", 400)
    if not isinstance(data, dict):
        return _error_response("JSON body must be an object", 400)

    # A JSON null must count as "no prompt"; str(None) would otherwise
    # become the literal prompt "None".
    raw_prompt = data.get("prompt")
    prompt = "" if raw_prompt is None else str(raw_prompt).strip()
    if not prompt:
        return _error_response("Prompt required", 400)

    raw_seed = data.get("seed")
    try:
        seed = 42 if raw_seed is None else int(raw_seed)
    except (TypeError, ValueError, OverflowError):
        return _error_response("Seed must be an integer", 400)
    if not seed_min <= seed <= seed_max:
        return _error_response("Seed out of range", 400)

    try:
        image = await generate_image_async(prompt, seed)
        buf = io.BytesIO()
        image.save(buf, format="PNG")
        img64 = base64.b64encode(buf.getvalue()).decode("utf-8")
        return JSONResponse({"status": "success", "prompt": prompt, "image_base64": img64})
    except Exception as e:
        print(f"❌ Error: {e}")
        return _error_response(str(e), 500)

@spaces.GPU
def keep_alive():
    """Return a fixed readiness string; the function is wrapped by ``spaces.GPU``.

    NOTE(review): presumably this exists so the ZeroGPU runtime sees a
    GPU-decorated function at startup — confirm against the Spaces docs.
    """
    status = "ZeroGPU Ready"
    return status

if __name__ == "__main__":
    # Script entry point: uvicorn is imported only when run directly.
    from uvicorn import run as serve

    print("🚀 Launching Fast FLUX API")
    # Call the GPU-decorated function once before the server starts.
    keep_alive()
    # Listen on all interfaces, port 7860.
    serve(app, host="0.0.0.0", port=7860)