1inkusFace commited on
Commit
413e290
·
verified ·
1 Parent(s): 58cc351

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -10
app.py CHANGED
@@ -2,7 +2,7 @@ import subprocess
2
  subprocess.run(['sh', './spaces.sh'])
3
 
4
  import os
5
- # Environment variable setup
6
  os.environ['PYTORCH_NVML_BASED_CUDA_CHECK'] = '1'
7
  os.environ['TORCH_LINALG_PREFER_CUSOLVER'] = '1'
8
  os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True,pinned_use_background_threads:True'
@@ -17,7 +17,6 @@ import datetime
17
  import threading
18
  import io
19
 
20
- # --- New GCS Imports ---
21
  from google.oauth2 import service_account
22
  from google.cloud import storage
23
 
@@ -27,9 +26,8 @@ import torch
27
  def install_flashattn():
28
  subprocess.run(['sh', './flashattn.sh'])
29
 
30
- #install_flashattn()
31
 
32
- # Torch performance settings
33
  torch.backends.cuda.matmul.allow_tf32 = False
34
  torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
35
  torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
@@ -45,13 +43,10 @@ from PIL import Image
45
  from image_gen_aux import UpscaleWithModel
46
 
47
 
48
- # --- GCS Configuration ---
49
- # Make sure to set these secrets in your Hugging Face Space settings
50
  GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
51
  GCS_SA_KEY = os.getenv("GCS_SA_KEY") # The full JSON key content as a string
52
-
53
- # Initialize GCS client if credentials are available
54
  gcs_client = None
 
55
  if GCS_SA_KEY and GCS_BUCKET_NAME:
56
  try:
57
  credentials_info = eval(GCS_SA_KEY) # Using eval is safe here if you trust the secret source
@@ -79,6 +74,50 @@ def upload_to_gcs(image_object, filename):
79
 
80
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  @spaces.GPU(duration=120)
83
  def compile_transformer():
84
  with spaces.aoti_capture(pipe.transformer) as call:
@@ -106,13 +145,18 @@ def load_model():
106
 
107
  pipe, upscaler_2 = load_model()
108
 
 
 
 
 
 
 
109
  compiled_transformer = compile_transformer()
110
  spaces.aoti_apply(compiled_transformer, pipe.transformer)
111
 
112
  MAX_SEED = np.iinfo(np.int32).max
113
  MAX_IMAGE_SIZE = 4096
114
 
115
-
116
  @spaces.GPU(duration=45)
117
  def generate_images_30(prompt, neg_prompt_1, neg_prompt_2, neg_prompt_3, width, height, guidance, steps, progress=gr.Progress(track_tqdm=True)):
118
  seed = random.randint(0, MAX_SEED)
@@ -234,6 +278,7 @@ css = """
234
  #col-container {margin: 0 auto;max-width: 640px;}
235
  body{background-color: blue;}
236
  """
 
237
  with gr.Blocks(theme=gr.themes.Origin(), css=css) as demo:
238
  with gr.Column(elem_id="col-container"):
239
  gr.Markdown(" # StableDiffusion 3.5 Large with UltraReal lora test")
@@ -310,7 +355,6 @@ with gr.Blocks(theme=gr.themes.Origin(), css=css) as demo:
310
  ],
311
  outputs=[result, expanded_prompt_output],
312
  )
313
-
314
 
315
  if __name__ == "__main__":
316
  demo.launch()
 
2
  subprocess.run(['sh', './spaces.sh'])
3
 
4
  import os
5
+
6
  os.environ['PYTORCH_NVML_BASED_CUDA_CHECK'] = '1'
7
  os.environ['TORCH_LINALG_PREFER_CUSOLVER'] = '1'
8
  os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True,pinned_use_background_threads:True'
 
17
  import threading
18
  import io
19
 
 
20
  from google.oauth2 import service_account
21
  from google.cloud import storage
22
 
 
26
  def install_flashattn():
27
  subprocess.run(['sh', './flashattn.sh'])
28
 
29
+ install_flashattn()
30
 
 
31
  torch.backends.cuda.matmul.allow_tf32 = False
32
  torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
33
  torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
 
43
  from image_gen_aux import UpscaleWithModel
44
 
45
 
 
 
46
  GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
47
  GCS_SA_KEY = os.getenv("GCS_SA_KEY") # The full JSON key content as a string
 
 
48
  gcs_client = None
49
+
50
  if GCS_SA_KEY and GCS_BUCKET_NAME:
51
  try:
52
  credentials_info = eval(GCS_SA_KEY) # Using eval is safe here if you trust the secret source
 
74
 
75
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
76
 
77
+ from diffusers.models.attention_processor import AttnProcessor2_0
78
+ from kernels import get_kernel
79
+ fa3_kernel = get_kernel("kernels-community/flash-attn3") # Or vllm-flash-attn3
80
+ class FlashAttentionProcessor(AttnProcessor2_0):
81
+ def __call__(
82
+ self,
83
+ attn,
84
+ hidden_states,
85
+ encoder_hidden_states=None, # This will be present for cross-attention
86
+ attention_mask=None,
87
+ temb=None, # This might be present in some attention mechanisms, pass through if not used directly
88
+ **kwargs,
89
+ ):
90
+ # Determine if it's self-attention or cross-attention
91
+ # For self-attention, encoder_hidden_states is None or identical to hidden_states
92
+ is_cross_attention = encoder_hidden_states is not None and encoder_hidden_states.shape[1] != hidden_states.shape[1]
93
+ # SD3.5 uses DiT, where hidden_states are often 3D (B, Seq, Dim)
94
+ # However, attention can be within a transformer block which might internally reshape.
95
+ # Ensure your inputs (query, key, value) are properly shaped for the kernel.
96
+ # The kernel expects (Batch, Heads, Sequence, Dim_Head)
97
+ query = attn.to_q(hidden_states)
98
+ if is_cross_attention:
99
+ key = attn.to_k(encoder_hidden_states)
100
+ value = attn.to_v(encoder_hidden_states)
101
+ else: # Self-attention
102
+ key = attn.to_k(hidden_states)
103
+ value = attn.to_v(hidden_states)
104
+ scale = attn.scale
105
+ query = query * scale
106
+ b, t, c = query.shape # B=batch_size, T=sequence_length, C=embedding_dim
107
+ h = attn.heads
108
+ d = c // h # dim_per_head
109
+ # Reshape to (Batch, Heads, Sequence, Dim_Head) for Flash Attention kernel
110
+ q_reshaped = query.reshape(b, t, h, d).permute(0, 2, 1, 3)
111
+ k_reshaped = key.reshape(b, t, h, d).permute(0, 2, 1, 3)
112
+ v_reshaped = value.reshape(b, t, h, d).permute(0, 2, 1, 3)
113
+ out_reshaped = torch.empty_like(q_reshaped)
114
+ # Call the Flash Attention kernel
115
+ fa3_kernel.attention(q_reshaped, k_reshaped, v_reshaped, out_reshaped)
116
+ # Reshape output back to (Batch, Sequence, Heads * Dim_Head)
117
+ out = out_reshaped.permute(0, 2, 1, 3).reshape(b, t, c)
118
+ out = attn.to_out(out)
119
+ return out
120
+
121
  @spaces.GPU(duration=120)
122
  def compile_transformer():
123
  with spaces.aoti_capture(pipe.transformer) as call:
 
145
 
146
  pipe, upscaler_2 = load_model()
147
 
148
+ fa_processor = FlashAttentionProcessor()
149
+
150
+ for name, module in pipe.transformer.named_modules():
151
+ if isinstance(module, AttnProcessor2_0):
152
+ module.processor = fa_processor
153
+
154
  compiled_transformer = compile_transformer()
155
  spaces.aoti_apply(compiled_transformer, pipe.transformer)
156
 
157
  MAX_SEED = np.iinfo(np.int32).max
158
  MAX_IMAGE_SIZE = 4096
159
 
 
160
  @spaces.GPU(duration=45)
161
  def generate_images_30(prompt, neg_prompt_1, neg_prompt_2, neg_prompt_3, width, height, guidance, steps, progress=gr.Progress(track_tqdm=True)):
162
  seed = random.randint(0, MAX_SEED)
 
278
  #col-container {margin: 0 auto;max-width: 640px;}
279
  body{background-color: blue;}
280
  """
281
+
282
  with gr.Blocks(theme=gr.themes.Origin(), css=css) as demo:
283
  with gr.Column(elem_id="col-container"):
284
  gr.Markdown(" # StableDiffusion 3.5 Large with UltraReal lora test")
 
355
  ],
356
  outputs=[result, expanded_prompt_output],
357
  )
 
358
 
359
  if __name__ == "__main__":
360
  demo.launch()