videopix committed
Commit 5780526 · verified · 1 Parent(s): 0fd9ca7

Update app.py

Files changed (1)
app.py +21 -36
app.py CHANGED
@@ -1,20 +1,19 @@
-import spaces  # must come first before any torch import
+import spaces  # must come first before any torch or CUDA import
 import io
 import base64
+import torch
 import asyncio
 from fastapi import FastAPI, Request, Form
 from fastapi.responses import HTMLResponse, JSONResponse
 from fastapi.middleware.cors import CORSMiddleware
 from PIL import Image
-import torch
 from diffusers import StableDiffusionPipeline
-from peft import PeftModel  # required for LoRA
+from safetensors.torch import load_file
 
 # ======================
 # App Setup
 # ======================
 app = FastAPI(title="Ghibli + Anime Image Generator (ZeroGPU CPU mode)")
-
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
@@ -27,49 +26,42 @@ app.add_middleware(
 # ======================
 device = "cpu"
 dtype = torch.float32
+BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
 
 print("🔹 Loading base SDXL model (CPU)...")
-base_model = "stabilityai/stable-diffusion-xl-base-1.0"
-
-# Load once and reuse
 pipe_ghibli = StableDiffusionPipeline.from_pretrained(
-    base_model,
+    BASE_MODEL,
     torch_dtype=dtype,
-    safety_checker=None,  # You can re-enable this if you prefer
 )
 pipe_ghibli.to(device)
 
 pipe_anime = StableDiffusionPipeline.from_pretrained(
-    base_model,
+    BASE_MODEL,
     torch_dtype=dtype,
-    safety_checker=None,
 )
 pipe_anime.to(device)
 
-# --- Apply LoRA weights manually via PEFT ---
-try:
-    print("🎨 Loading Ghibli LoRA...")
-    PeftModel.from_pretrained(pipe_ghibli.unet, "./studioghibli_flux_r32-v2.safetensors")
-    print("✅ Ghibli LoRA loaded.")
-except Exception as e:
-    print(f"⚠️ Ghibli LoRA load failed: {e}")
-
-try:
-    print("🎨 Loading Anime LoRA...")
-    PeftModel.from_pretrained(pipe_anime.unet, "./Flux_1_Dev_LoRA_AestheticAnime.safetensors")
-    print("✅ Anime LoRA loaded.")
-except Exception as e:
-    print(f"⚠️ Anime LoRA load failed: {e}")
+# --- Load LoRA weights directly ---
+def apply_lora(pipe, path):
+    try:
+        print(f"🎨 Applying LoRA: {path}")
+        lora_weights = load_file(path)
+        missing, unexpected = pipe.unet.load_state_dict(lora_weights, strict=False)
+        print(f"✅ LoRA loaded successfully. Missing: {len(missing)}, Unexpected: {len(unexpected)}")
+    except Exception as e:
+        print(f"⚠️ Failed to load LoRA ({path}): {e}")
+
+apply_lora(pipe_ghibli, "./studioghibli_flux_r32-v2.safetensors")
+apply_lora(pipe_anime, "./Flux_1_Dev_LoRA_AestheticAnime.safetensors")
 
 # ======================
 # Helper functions
 # ======================
-
 async def generate_image(prompt: str, style: str, seed: int):
     generator = torch.Generator(device=device).manual_seed(seed)
     pipe = pipe_anime if style.lower() == "anime" else pipe_ghibli
-
     print(f"🎨 Generating {style} image for: {prompt}")
+
     image = pipe(
         prompt=f"{prompt}, {style} style, cinematic lighting, high quality",
         num_inference_steps=30,
@@ -94,9 +86,8 @@ def image_to_base64(img: Image.Image) -> str:
 
 
 # ======================
-# API Routes
+# Routes
 # ======================
-
 @app.get("/", response_class=HTMLResponse)
 def home():
     return """
@@ -114,7 +105,7 @@ def home():
 
 
 @app.post("/api/generate")
-async def api_generate(request: Request, prompt: str = Form(...), style: str = Form(...), seed: int = Form(42)):
+async def api_generate(prompt: str = Form(...), style: str = Form("Ghibli"), seed: int = Form(42)):
     try:
         print(f"📩 Received: {prompt} | Style={style} | Seed={seed}")
         image = await generate_image(prompt, style, seed)
@@ -126,9 +117,6 @@ async def api_generate(request: Request, prompt: str = Form(...), style: str = Form(...), seed: int = Form(42)):
         return JSONResponse({"status": "error", "message": str(e)}, status_code=500)
 
 
-# ======================
-# ZeroGPU keep-alive
-# ======================
 @spaces.GPU
 def keep_alive():
     """Required for ZeroGPU Spaces — keeps container active."""
@@ -136,9 +124,6 @@ def keep_alive():
     return "OK"
 
 
-# ======================
-# Run the app
-# ======================
 if __name__ == "__main__":
     import uvicorn
     print("🚀 Launching FastAPI on port 7860 (ZeroGPU mode)")
 