# app.py
import os
import uuid
import shutil
import logging
import requests
import asyncio
import time
from typing import Optional, Dict, Any
from fastapi import FastAPI, UploadFile, File, HTTPException, Query, Form
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
from huggingface_hub import login
from app.utils import run_inference
# --- Configuration / env ---
# Optional Hugging Face token; used to access gated model weights.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    try:
        login(token=hf_token)
    except Exception as e:
        # Non-fatal if login fails in some deployments, but record why so
        # later gated-model download failures can be traced back here.
        # Use a named logger (not logging.warning) so we don't implicitly
        # trigger basicConfig before the explicit configuration below.
        logging.getLogger("stable-fast-3d-api").warning(
            "Hugging Face login failed (continuing without auth): %s", e
        )
# Scratch directory for uploaded inputs and per-job output directories.
TMP_DIR = os.environ.get("TMP_DIR", "/app/tmp")
os.makedirs(TMP_DIR, exist_ok=True)
# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("stable-fast-3d-api")
app = FastAPI(title="Stable Fast 3D API (Background Jobs)")
# In-memory job registry.
# NOTE: per-process and non-persistent — jobs are lost on restart and are
# not shared across multiple workers (run one worker, or replace with an
# external store such as Redis if scaling out).
# Structure:
# JOBS[request_id] = {
#     "status": "pending" | "running" | "done" | "error",
#     "input_path": "...",
#     "output_dir": "...",
#     "glb_path": Optional[str],
#     "error": Optional[str],
#     "created_at": float,
#     "started_at": Optional[float],
#     "finished_at": Optional[float],
# }
JOBS: Dict[str, Dict[str, Any]] = {}
# Guards all reads/writes of JOBS (the dict and its per-job entries).
JOBS_LOCK = asyncio.Lock()
# -------------------------
# Utility helpers
# -------------------------
def _save_upload_file(upload_file: UploadFile, dest_path: str) -> None:
    """Persist an uploaded file's contents to *dest_path*.

    Streams with shutil.copyfileobj so large uploads are not buffered
    entirely in memory. The source file object is closed unconditionally:
    the original only closed it on success, leaking the spooled temp file
    when the copy raised.
    """
    try:
        with open(dest_path, "wb") as f:
            shutil.copyfileobj(upload_file.file, f)
    finally:
        upload_file.file.close()
def _download_to_file(url: str, dest_path: str, timeout: int = 30) -> None:
    """Stream a remote resource at *url* into *dest_path*.

    Raises:
        HTTPException(400) on any non-200 response so the API surfaces a
        client-facing error instead of a generic 500.
    """
    # `with` releases the pooled connection even when we raise or the body
    # is only partially read; the original leaked the response on the
    # non-200 path.
    with requests.get(url, stream=True, timeout=timeout) as resp:
        if resp.status_code != 200:
            raise HTTPException(status_code=400, detail=f"Failed to download image: status {resp.status_code}")
        with open(dest_path, "wb") as f:
            for chunk in resp.iter_content(chunk_size=8192):
                if chunk:  # skip keep-alive chunks
                    f.write(chunk)
def _find_glb_in_dir(output_dir: str) -> Optional[str]:
    """Return the first *.glb file found anywhere under output_dir, or None.

    Matching is case-insensitive on the extension; traversal order follows
    os.walk (top-down).
    """
    candidates = (
        os.path.join(dirpath, name)
        for dirpath, _dirs, names in os.walk(output_dir)
        for name in names
        if name.lower().endswith(".glb")
    )
    return next(candidates, None)
async def _set_job_field(job_id: str, key: str, value):
    """Update a single field on a registered job under the registry lock.

    Silently does nothing if the job has been deleted in the meantime.
    """
    async with JOBS_LOCK:
        job = JOBS.get(job_id)
        if job is not None:
            job[key] = value
async def _get_job(job_id: str):
    """Fetch the job dict for *job_id* under the registry lock (None if unknown)."""
    async with JOBS_LOCK:
        job = JOBS.get(job_id)
    return job
# -------------------------
# Background worker
# -------------------------
async def _update_job_fields(job_id: str, **fields) -> None:
    """Set several fields on a job atomically under one lock hold.

    The original code acquired the lock once per field, so readers could
    observe half-written transitions (e.g. status == "done" while
    finished_at was still None). No-op if the job has been deleted.
    """
    async with JOBS_LOCK:
        job = JOBS.get(job_id)
        if job is not None:
            job.update(fields)


async def _background_run_inference(job_id: str):
    """Run the heavy, synchronous run_inference for a job in a worker thread.

    Drives the job through its lifecycle in the registry:
    pending -> running -> done (glb_path set) or error (error message set).
    """
    job = await _get_job(job_id)
    if not job:
        logger.error("Job not found when starting background task: %s", job_id)
        return
    input_path = job["input_path"]
    output_dir = job["output_dir"]
    logger.info("[%s] Background job starting. input=%s output=%s", job_id, input_path, output_dir)
    await _update_job_fields(job_id, status="running", started_at=time.time())
    try:
        # run_inference is synchronous / heavy — move to a thread so the
        # event loop stays responsive.
        glb_path = await asyncio.to_thread(run_inference, input_path, output_dir)
        # If run_inference returned None or a stale path, try to discover a .glb.
        if not glb_path or not os.path.exists(glb_path):
            found = _find_glb_in_dir(output_dir)
            if found:
                glb_path = found
        if not glb_path or not os.path.exists(glb_path):
            # Include the full output listing so "no GLB" failures are debuggable.
            listing = [
                os.path.join(root, fn)
                for root, _, files in os.walk(output_dir)
                for fn in files
            ]
            raise RuntimeError(f"GLB not produced. output_dir listing: {listing}")
        # Mark success — one atomic transition.
        await _update_job_fields(job_id, glb_path=glb_path, status="done", finished_at=time.time())
        logger.info("[%s] Background job finished successfully. glb=%s", job_id, glb_path)
    except Exception as e:
        logger.exception("[%s] Background inference failed: %s", job_id, e)
        await _update_job_fields(job_id, status="error", error=str(e), finished_at=time.time())
# -------------------------
# Embedded UI root (polling-based)
# -------------------------
@app.get("/", response_class=HTMLResponse)
async def root_ui():
    """Serve the embedded landing page for the polling-based UI.

    NOTE(review): the returned body is plain text with no HTML tags — the
    markup appears to have been stripped at some point; confirm against
    the intended page before relying on this endpoint.
    """
    html = """
Stable Fast 3D API — Background Jobs
Stable Fast 3D API — Background Jobs
Upload an image or paste an image URL to generate a 3D model (GLB). The job runs server-side and continues even if you close this page.
Status: idle
Job ID:
Waiting...
"""
    return HTMLResponse(content=html, status_code=200)
# -------------------------
# API: Start job (non-blocking)
# -------------------------
@app.post("/generate-3d/")
async def generate_3d_start(
    image: Optional[UploadFile] = File(None),
    image_url: Optional[str] = Form(None),
):
    """
    Start a background job to generate a 3D model.

    Exactly one of `image` (multipart upload) or `image_url` must be given.
    Returns JSON: { "id": "", "status_url": "/status/", "download_url": "/download/" }

    Raises:
        HTTPException 400 if no input was provided or the URL fetch failed.
        HTTPException 500 if the input could not be persisted.
    """
    request_id = str(uuid.uuid4())
    input_path = os.path.join(TMP_DIR, f"{request_id}.png")
    output_dir = os.path.join(TMP_DIR, f"{request_id}_output")
    os.makedirs(output_dir, exist_ok=True)
    # Save input. Both helpers perform blocking disk/network I/O, so run them
    # in a worker thread — calling them directly (as the original did) stalls
    # the event loop for every other request during a slow download.
    try:
        if image is not None:
            await asyncio.to_thread(_save_upload_file, image, input_path)
        elif image_url:
            # NOTE(review): the URL is fetched server-side with no scheme/host
            # validation — an SSRF vector if this service is publicly exposed;
            # consider an allowlist of schemes/hosts.
            await asyncio.to_thread(_download_to_file, image_url, input_path, timeout=30)
        else:
            raise HTTPException(status_code=400, detail="Either image or image_url must be provided")
    except Exception as e:
        # Don't leak the just-created output dir / partial input for a
        # rejected or failed request.
        shutil.rmtree(output_dir, ignore_errors=True)
        if os.path.exists(input_path):
            os.remove(input_path)
        if isinstance(e, HTTPException):
            raise
        logger.exception("Failed to save input for job %s: %s", request_id, e)
        raise HTTPException(status_code=500, detail=f"Failed to save input: {e}")
    # Register job (pending)
    async with JOBS_LOCK:
        JOBS[request_id] = {
            "status": "pending",
            "input_path": input_path,
            "output_dir": output_dir,
            "glb_path": None,
            "error": None,
            "created_at": time.time(),
            "started_at": None,
            "finished_at": None,
        }
    # Kick off background task (does not block the request). The job must be
    # registered first so the task can find it.
    task = asyncio.create_task(_background_run_inference(request_id))
    # asyncio keeps only a weak reference to scheduled tasks — stash a strong
    # reference in the job entry so the task can't be garbage-collected mid-run.
    async with JOBS_LOCK:
        job = JOBS.get(request_id)
        if job is not None:
            job["task"] = task
    logger.info("Started background job %s", request_id)
    return JSONResponse({
        "id": request_id,
        "status_url": f"/status/{request_id}",
        "download_url": f"/download/{request_id}",
    })
# -------------------------
# API: Check status
# -------------------------
@app.get("/status/{job_id}")
async def job_status(job_id: str):
    """Return the public status fields for a job.

    Raises:
        HTTPException 404 if the job id is unknown.
    """
    job = await _get_job(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    has_glb = bool(job.get("glb_path"))
    # Return only the public fields.
    return JSONResponse({
        "id": job_id,
        "status": job["status"],
        # "glb_path" has always been a boolean despite its name; kept for
        # backward compatibility. Prefer "has_glb", which matches /jobs.
        "glb_path": has_glb,
        "has_glb": has_glb,
        "error": job.get("error"),
        "created_at": job.get("created_at"),
        "started_at": job.get("started_at"),
        "finished_at": job.get("finished_at"),
    })
# -------------------------
# API: Download result (if ready)
# -------------------------
@app.get("/download/{job_id}")
async def download_result(job_id: str):
    """Stream the finished GLB for a job; 404 until the job is done."""
    job = await _get_job(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")
    glb_path = job.get("glb_path")
    if job["status"] != "done" or not glb_path:
        # Not ready
        raise HTTPException(status_code=404, detail="Result not ready")
    if not os.path.exists(glb_path):
        raise HTTPException(status_code=404, detail="GLB file missing on disk")
    # The file stays on disk; callers use DELETE /delete/{job_id} to remove it.
    return FileResponse(
        path=glb_path,
        media_type="model/gltf-binary",
        filename=os.path.basename(glb_path),
    )
# -------------------------
# API: Delete job & files (manual)
# -------------------------
@app.delete("/delete/{job_id}")
async def delete_job(job_id: str):
    """Remove a job's files and its registry entry.

    Returns {"deleted": true} plus any non-fatal file-removal errors.
    Raises HTTPException 404 if the job id is unknown.

    NOTE(review): deleting a job whose status is "running" pulls the files
    out from under the inference thread; consider rejecting with 409.
    """
    # Claim-and-remove the registry entry atomically FIRST, so concurrent
    # /download or /status requests can't race against the file removal
    # below (the original removed files while the entry was still visible).
    async with JOBS_LOCK:
        job = JOBS.pop(job_id, None)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")
    input_path = job.get("input_path")
    output_dir = job.get("output_dir")
    errors = []
    try:
        if input_path and os.path.exists(input_path):
            os.remove(input_path)
    except Exception as e:
        errors.append(f"input removal error: {e}")
    try:
        if output_dir and os.path.exists(output_dir):
            # No ignore_errors=True here: the original combined it with this
            # try/except, which made the except branch dead code and hid
            # removal failures from the response.
            shutil.rmtree(output_dir)
    except Exception as e:
        errors.append(f"output dir removal error: {e}")
    if errors:
        logger.warning("Delete job %s completed with errors: %s", job_id, errors)
        return JSONResponse({"deleted": True, "errors": errors})
    return JSONResponse({"deleted": True})
# -------------------------
# API: List jobs (optional)
# -------------------------
@app.get("/jobs")
async def list_jobs():
    """Summarize every registered job: status, timestamps, result availability."""
    async with JOBS_LOCK:
        summary = {}
        for job_id, job in JOBS.items():
            summary[job_id] = {
                "status": job["status"],
                "created_at": job["created_at"],
                "started_at": job["started_at"],
                "finished_at": job["finished_at"],
                "has_glb": bool(job.get("glb_path")),
            }
    return JSONResponse(summary)