Update app.py
Browse files
app.py
CHANGED
|
@@ -2,7 +2,7 @@ import subprocess
|
|
| 2 |
subprocess.run(['sh', './spaces.sh'])
|
| 3 |
|
| 4 |
import os
|
| 5 |
-
|
| 6 |
os.environ['PYTORCH_NVML_BASED_CUDA_CHECK'] = '1'
|
| 7 |
os.environ['TORCH_LINALG_PREFER_CUSOLVER'] = '1'
|
| 8 |
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True,pinned_use_background_threads:True'
|
|
@@ -17,7 +17,6 @@ import datetime
|
|
| 17 |
import threading
|
| 18 |
import io
|
| 19 |
|
| 20 |
-
# --- New GCS Imports ---
|
| 21 |
from google.oauth2 import service_account
|
| 22 |
from google.cloud import storage
|
| 23 |
|
|
@@ -27,9 +26,8 @@ import torch
|
|
| 27 |
def install_flashattn():
|
| 28 |
subprocess.run(['sh', './flashattn.sh'])
|
| 29 |
|
| 30 |
-
|
| 31 |
|
| 32 |
-
# Torch performance settings
|
| 33 |
torch.backends.cuda.matmul.allow_tf32 = False
|
| 34 |
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
|
| 35 |
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
|
|
@@ -45,13 +43,10 @@ from PIL import Image
|
|
| 45 |
from image_gen_aux import UpscaleWithModel
|
| 46 |
|
| 47 |
|
| 48 |
-
# --- GCS Configuration ---
|
| 49 |
-
# Make sure to set these secrets in your Hugging Face Space settings
|
| 50 |
GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
|
| 51 |
GCS_SA_KEY = os.getenv("GCS_SA_KEY") # The full JSON key content as a string
|
| 52 |
-
|
| 53 |
-
# Initialize GCS client if credentials are available
|
| 54 |
gcs_client = None
|
|
|
|
| 55 |
if GCS_SA_KEY and GCS_BUCKET_NAME:
|
| 56 |
try:
|
| 57 |
credentials_info = eval(GCS_SA_KEY) # Using eval is safe here if you trust the secret source
|
|
@@ -79,6 +74,50 @@ def upload_to_gcs(image_object, filename):
|
|
| 79 |
|
| 80 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
@spaces.GPU(duration=120)
|
| 83 |
def compile_transformer():
|
| 84 |
with spaces.aoti_capture(pipe.transformer) as call:
|
|
@@ -106,13 +145,18 @@ def load_model():
|
|
| 106 |
|
| 107 |
pipe, upscaler_2 = load_model()
|
| 108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
compiled_transformer = compile_transformer()
|
| 110 |
spaces.aoti_apply(compiled_transformer, pipe.transformer)
|
| 111 |
|
| 112 |
MAX_SEED = np.iinfo(np.int32).max
|
| 113 |
MAX_IMAGE_SIZE = 4096
|
| 114 |
|
| 115 |
-
|
| 116 |
@spaces.GPU(duration=45)
|
| 117 |
def generate_images_30(prompt, neg_prompt_1, neg_prompt_2, neg_prompt_3, width, height, guidance, steps, progress=gr.Progress(track_tqdm=True)):
|
| 118 |
seed = random.randint(0, MAX_SEED)
|
|
@@ -234,6 +278,7 @@ css = """
|
|
| 234 |
#col-container {margin: 0 auto;max-width: 640px;}
|
| 235 |
body{background-color: blue;}
|
| 236 |
"""
|
|
|
|
| 237 |
with gr.Blocks(theme=gr.themes.Origin(), css=css) as demo:
|
| 238 |
with gr.Column(elem_id="col-container"):
|
| 239 |
gr.Markdown(" # StableDiffusion 3.5 Large with UltraReal lora test")
|
|
@@ -310,7 +355,6 @@ with gr.Blocks(theme=gr.themes.Origin(), css=css) as demo:
|
|
| 310 |
],
|
| 311 |
outputs=[result, expanded_prompt_output],
|
| 312 |
)
|
| 313 |
-
|
| 314 |
|
| 315 |
if __name__ == "__main__":
|
| 316 |
demo.launch()
|
|
|
|
| 2 |
subprocess.run(['sh', './spaces.sh'])
|
| 3 |
|
| 4 |
import os
|
| 5 |
+
|
| 6 |
os.environ['PYTORCH_NVML_BASED_CUDA_CHECK'] = '1'
|
| 7 |
os.environ['TORCH_LINALG_PREFER_CUSOLVER'] = '1'
|
| 8 |
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True,pinned_use_background_threads:True'
|
|
|
|
| 17 |
import threading
|
| 18 |
import io
|
| 19 |
|
|
|
|
| 20 |
from google.oauth2 import service_account
|
| 21 |
from google.cloud import storage
|
| 22 |
|
|
|
|
| 26 |
def install_flashattn():
|
| 27 |
subprocess.run(['sh', './flashattn.sh'])
|
| 28 |
|
| 29 |
+
install_flashattn()
|
| 30 |
|
|
|
|
| 31 |
torch.backends.cuda.matmul.allow_tf32 = False
|
| 32 |
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
|
| 33 |
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
|
|
|
|
| 43 |
from image_gen_aux import UpscaleWithModel
|
| 44 |
|
| 45 |
|
|
|
|
|
|
|
| 46 |
GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
|
| 47 |
GCS_SA_KEY = os.getenv("GCS_SA_KEY") # The full JSON key content as a string
|
|
|
|
|
|
|
| 48 |
gcs_client = None
|
| 49 |
+
|
| 50 |
if GCS_SA_KEY and GCS_BUCKET_NAME:
|
| 51 |
try:
|
| 52 |
credentials_info = eval(GCS_SA_KEY) # Using eval is safe here if you trust the secret source
|
|
|
|
| 74 |
|
| 75 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 76 |
|
| 77 |
+
from diffusers.models.attention_processor import AttnProcessor2_0
|
| 78 |
+
from kernels import get_kernel
|
| 79 |
+
fa3_kernel = get_kernel("kernels-community/flash-attn3") # Or vllm-flash-attn3
|
| 80 |
+
class FlashAttentionProcessor(AttnProcessor2_0):
|
| 81 |
+
def __call__(
|
| 82 |
+
self,
|
| 83 |
+
attn,
|
| 84 |
+
hidden_states,
|
| 85 |
+
encoder_hidden_states=None, # This will be present for cross-attention
|
| 86 |
+
attention_mask=None,
|
| 87 |
+
temb=None, # This might be present in some attention mechanisms, pass through if not used directly
|
| 88 |
+
**kwargs,
|
| 89 |
+
):
|
| 90 |
+
# Determine if it's self-attention or cross-attention
|
| 91 |
+
# For self-attention, encoder_hidden_states is None or identical to hidden_states
|
| 92 |
+
is_cross_attention = encoder_hidden_states is not None and encoder_hidden_states.shape[1] != hidden_states.shape[1]
|
| 93 |
+
# SD3.5 uses DiT, where hidden_states are often 3D (B, Seq, Dim)
|
| 94 |
+
# However, attention can be within a transformer block which might internally reshape.
|
| 95 |
+
# Ensure your inputs (query, key, value) are properly shaped for the kernel.
|
| 96 |
+
# The kernel expects (Batch, Heads, Sequence, Dim_Head)
|
| 97 |
+
query = attn.to_q(hidden_states)
|
| 98 |
+
if is_cross_attention:
|
| 99 |
+
key = attn.to_k(encoder_hidden_states)
|
| 100 |
+
value = attn.to_v(encoder_hidden_states)
|
| 101 |
+
else: # Self-attention
|
| 102 |
+
key = attn.to_k(hidden_states)
|
| 103 |
+
value = attn.to_v(hidden_states)
|
| 104 |
+
scale = attn.scale
|
| 105 |
+
query = query * scale
|
| 106 |
+
b, t, c = query.shape # B=batch_size, T=sequence_length, C=embedding_dim
|
| 107 |
+
h = attn.heads
|
| 108 |
+
d = c // h # dim_per_head
|
| 109 |
+
# Reshape to (Batch, Heads, Sequence, Dim_Head) for Flash Attention kernel
|
| 110 |
+
q_reshaped = query.reshape(b, t, h, d).permute(0, 2, 1, 3)
|
| 111 |
+
k_reshaped = key.reshape(b, t, h, d).permute(0, 2, 1, 3)
|
| 112 |
+
v_reshaped = value.reshape(b, t, h, d).permute(0, 2, 1, 3)
|
| 113 |
+
out_reshaped = torch.empty_like(q_reshaped)
|
| 114 |
+
# Call the Flash Attention kernel
|
| 115 |
+
fa3_kernel.attention(q_reshaped, k_reshaped, v_reshaped, out_reshaped)
|
| 116 |
+
# Reshape output back to (Batch, Sequence, Heads * Dim_Head)
|
| 117 |
+
out = out_reshaped.permute(0, 2, 1, 3).reshape(b, t, c)
|
| 118 |
+
out = attn.to_out(out)
|
| 119 |
+
return out
|
| 120 |
+
|
| 121 |
@spaces.GPU(duration=120)
|
| 122 |
def compile_transformer():
|
| 123 |
with spaces.aoti_capture(pipe.transformer) as call:
|
|
|
|
| 145 |
|
| 146 |
pipe, upscaler_2 = load_model()
|
| 147 |
|
| 148 |
+
fa_processor = FlashAttentionProcessor()
|
| 149 |
+
|
| 150 |
+
for name, module in pipe.transformer.named_modules():
|
| 151 |
+
if isinstance(module, AttnProcessor2_0):
|
| 152 |
+
module.processor = fa_processor
|
| 153 |
+
|
| 154 |
compiled_transformer = compile_transformer()
|
| 155 |
spaces.aoti_apply(compiled_transformer, pipe.transformer)
|
| 156 |
|
| 157 |
MAX_SEED = np.iinfo(np.int32).max
|
| 158 |
MAX_IMAGE_SIZE = 4096
|
| 159 |
|
|
|
|
| 160 |
@spaces.GPU(duration=45)
|
| 161 |
def generate_images_30(prompt, neg_prompt_1, neg_prompt_2, neg_prompt_3, width, height, guidance, steps, progress=gr.Progress(track_tqdm=True)):
|
| 162 |
seed = random.randint(0, MAX_SEED)
|
|
|
|
| 278 |
#col-container {margin: 0 auto;max-width: 640px;}
|
| 279 |
body{background-color: blue;}
|
| 280 |
"""
|
| 281 |
+
|
| 282 |
with gr.Blocks(theme=gr.themes.Origin(), css=css) as demo:
|
| 283 |
with gr.Column(elem_id="col-container"):
|
| 284 |
gr.Markdown(" # StableDiffusion 3.5 Large with UltraReal lora test")
|
|
|
|
| 355 |
],
|
| 356 |
outputs=[result, expanded_prompt_output],
|
| 357 |
)
|
|
|
|
| 358 |
|
| 359 |
if __name__ == "__main__":
|
| 360 |
demo.launch()
|