Update app.py
app.py CHANGED
@@ -191,17 +191,17 @@ def generate(
     guidance_scale = state["guidance_scale"]
     all_timesteps_cpu = state["all_timesteps"]
     timesteps_split_for_state = state["timesteps_split"]
-    timesteps_chunk_np = state["timesteps_split"][segment - 1]
-    segment_timesteps = torch.from_numpy(
+    #timesteps_chunk_np = state["timesteps_split"][segment - 1]
+    segment_timesteps = torch.from_numpy(timesteps_split_np[1]).to("cuda")
     seed = state["seed"]
     height = state["height"]
     width = state["width"]
     generator = torch.Generator(device='cuda').manual_seed(seed)
     pipe.scheduler.set_timesteps(len(all_timesteps_cpu), device=device)
     prompt_embeds = state["prompt_embeds"].to("cuda", dtype=torch.bfloat16)
-    negative_prompt_embeds = state["negative_prompt_embeds"].to("cuda", dtype=torch.bfloat16)
+    negative_prompt_embeds = state["negative_prompt_embeds"].to("cuda", dtype=torch.bfloat16)
     pooled_prompt_embeds = state["pooled_prompt_embeds"].to("cuda", dtype=torch.bfloat16)
-    negative_pooled_prompt_embeds = state["negative_pooled_prompt_embeds"].to("cuda", dtype=torch.bfloat16) if state["negative_pooled_prompt_embeds"]
+    negative_pooled_prompt_embeds = state["negative_pooled_prompt_embeds"].to("cuda", dtype=torch.bfloat16) if state["negative_pooled_prompt_embeds"]
     unet_prompt_embeds = prompt_embeds
     unet_added_text_embeds = pooled_prompt_embeds
     unet_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
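For context, this hunk restores everything the second segment needs from the CPU-side `state` dict: a chunk of the timestep schedule, a re-seeded generator, and the saved embeddings moved back to CUDA in bfloat16, with negative and positive embeddings concatenated for classifier-free guidance. A minimal sketch of that resume pattern, assuming a diffusers SDXL-style `pipe` and the `state` layout saved at the end of segment 1 (the helper name `restore_segment_inputs` is illustrative, not from the diff):

import torch

# Hedged sketch: restore per-segment inputs from a CPU-side checkpoint.
def restore_segment_inputs(pipe, state, segment, device="cuda"):
    # Pick this segment's chunk of the schedule (segment is 1-indexed here,
    # matching the commented-out line in the diff).
    chunk_np = state["timesteps_split"][segment - 1]
    segment_timesteps = torch.from_numpy(chunk_np).to(device)

    # Re-seeding with the saved seed keeps any scheduler noise reproducible
    # across segments.
    generator = torch.Generator(device=device).manual_seed(state["seed"])

    # The scheduler is re-initialised with the FULL step count so its internal
    # sigmas and counters match the original schedule, not just this chunk.
    pipe.scheduler.set_timesteps(len(state["all_timesteps"]), device=device)

    # Saved CPU embeddings come back to the GPU in bfloat16.
    prompt_embeds = state["prompt_embeds"].to(device, dtype=torch.bfloat16)
    negative_prompt_embeds = state["negative_prompt_embeds"].to(device, dtype=torch.bfloat16)

    # Classifier-free guidance: one batched UNet pass over [uncond, cond].
    unet_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
    return segment_timesteps, generator, unet_prompt_embeds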
@@ -222,7 +222,7 @@ def generate(
     current_latents = pipe.scheduler.step(noise_pred, t, current_latents, generator=generator, return_dict=False)[0]

     intermediate_latents_cpu = current_latents.detach().cpu() # Latents after first half, moved to CPU
-
+
     # test with 2 segments
     if segment==2:
         final_latents = current_latents
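The `noise_pred` passed to `pipe.scheduler.step` here is the guidance-combined UNet output; with `guidance_scale` restored from `state`, the standard classifier-free-guidance combine would look like the sketch below (the function and `noise_pred_batched` are assumed names, not from the diff):

import torch

# Hedged sketch of one denoising step feeding pipe.scheduler.step above.
def cfg_step(pipe, noise_pred_batched, t, current_latents, guidance_scale, generator):
    # Split the batched UNet output into unconditional and text-conditioned halves,
    # matching the torch.cat([negative_prompt_embeds, prompt_embeds]) ordering.
    noise_uncond, noise_text = noise_pred_batched.chunk(2)
    # Standard classifier-free guidance combine.
    noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
    # One scheduler update, matching the call shown in the diff.
    return pipe.scheduler.step(
        noise_pred, t, current_latents, generator=generator, return_dict=False
    )[0]

After the last step of a segment, the latents are detached and parked on the CPU (`intermediate_latents_cpu`) so they can be carried into the next segment without holding GPU memory.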
@@ -243,7 +243,7 @@ def generate(
     original_negative_prompt_embeds_cpu = negative_prompt_embeds.cpu()
     original_pooled_prompt_embeds_cpu = pooled_prompt_embeds.cpu()
     original_negative_pooled_prompt_embeds_cpu = negative_pooled_prompt_embeds.cpu()
-    original_add_time_ids_cpu = add_time_ids.cpu()
+    original_add_time_ids_cpu = add_time_ids.cpu()
     state = {
         "intermediate_latents": intermediate_latents_cpu,
         "all_timesteps": all_timesteps_cpu, # Save full list generated by scheduler
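The `state` dict assembled here is the CPU checkpoint that the first hunk later reads back. Only two keys are visible in this hunk; a hedged reconstruction of the full layout, with the remaining keys inferred from what the restore code reads out of `state` (the `original_prompt_embeds_cpu` name is an assumption):

# Hedged reconstruction of the checkpoint; keys beyond the two shown in the
# diff are inferred from the reads in the restore hunk above.
state = {
    "intermediate_latents": intermediate_latents_cpu,
    "all_timesteps": all_timesteps_cpu,            # full schedule from the scheduler
    "timesteps_split": timesteps_split_for_state,  # per-segment chunks
    "prompt_embeds": original_prompt_embeds_cpu,   # assumed name
    "negative_prompt_embeds": original_negative_prompt_embeds_cpu,
    "pooled_prompt_embeds": original_pooled_prompt_embeds_cpu,
    "negative_pooled_prompt_embeds": original_negative_pooled_prompt_embeds_cpu,
    "add_time_ids": original_add_time_ids_cpu,
    "seed": seed,
    "height": height,
    "width": width,
    "guidance_scale": guidance_scale,
}

Keeping every tensor on the CPU in this dict is what lets the GPU state be torn down entirely between segments.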