ford442 commited on
Commit
324564e
·
verified ·
1 Parent(s): 9b2099a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -191,17 +191,17 @@ def generate(
191
  guidance_scale = state["guidance_scale"]
192
  all_timesteps_cpu = state["all_timesteps"]
193
  timesteps_split_for_state = state["timesteps_split"]
194
- timesteps_chunk_np = state["timesteps_split"][segment - 1]
195
- segment_timesteps = torch.from_numpy(timesteps_chunk_np).to("cuda")
196
  seed = state["seed"]
197
  height = state["height"]
198
  width = state["width"]
199
  generator = torch.Generator(device='cuda').manual_seed(seed)
200
  pipe.scheduler.set_timesteps(len(all_timesteps_cpu), device=device)
201
  prompt_embeds = state["prompt_embeds"].to("cuda", dtype=torch.bfloat16)
202
- negative_prompt_embeds = state["negative_prompt_embeds"].to("cuda", dtype=torch.bfloat16) if state["negative_prompt_embeds"] is not None else None
203
  pooled_prompt_embeds = state["pooled_prompt_embeds"].to("cuda", dtype=torch.bfloat16)
204
- negative_pooled_prompt_embeds = state["negative_pooled_prompt_embeds"].to("cuda", dtype=torch.bfloat16) if state["negative_pooled_prompt_embeds"] is not None else None
205
  unet_prompt_embeds = prompt_embeds
206
  unet_added_text_embeds = pooled_prompt_embeds
207
  unet_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
@@ -222,7 +222,7 @@ def generate(
222
  current_latents = pipe.scheduler.step(noise_pred, t, current_latents, generator=generator, return_dict=False)[0]
223
 
224
  intermediate_latents_cpu = current_latents.detach().cpu() # Latents after first half, moved to CPU
225
-
226
  # test with 2 segments
227
  if segment==2:
228
  final_latents = current_latents
@@ -243,7 +243,7 @@ def generate(
243
  original_negative_prompt_embeds_cpu = negative_prompt_embeds.cpu()
244
  original_pooled_prompt_embeds_cpu = pooled_prompt_embeds.cpu()
245
  original_negative_pooled_prompt_embeds_cpu = negative_pooled_prompt_embeds.cpu()
246
- original_add_time_ids_cpu = add_time_ids.cpu() # Saved before CFG duplication
247
  state = {
248
  "intermediate_latents": intermediate_latents_cpu,
249
  "all_timesteps": all_timesteps_cpu, # Save full list generated by scheduler
 
191
  guidance_scale = state["guidance_scale"]
192
  all_timesteps_cpu = state["all_timesteps"]
193
  timesteps_split_for_state = state["timesteps_split"]
194
+ #timesteps_chunk_np = state["timesteps_split"][segment - 1]
195
+ segment_timesteps = torch.from_numpy(timesteps_split_np[1]).to("cuda")
196
  seed = state["seed"]
197
  height = state["height"]
198
  width = state["width"]
199
  generator = torch.Generator(device='cuda').manual_seed(seed)
200
  pipe.scheduler.set_timesteps(len(all_timesteps_cpu), device=device)
201
  prompt_embeds = state["prompt_embeds"].to("cuda", dtype=torch.bfloat16)
202
+ negative_prompt_embeds = state["negative_prompt_embeds"].to("cuda", dtype=torch.bfloat16)
203
  pooled_prompt_embeds = state["pooled_prompt_embeds"].to("cuda", dtype=torch.bfloat16)
204
+ negative_pooled_prompt_embeds = state["negative_pooled_prompt_embeds"].to("cuda", dtype=torch.bfloat16) if state["negative_pooled_prompt_embeds"]
205
  unet_prompt_embeds = prompt_embeds
206
  unet_added_text_embeds = pooled_prompt_embeds
207
  unet_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
 
222
  current_latents = pipe.scheduler.step(noise_pred, t, current_latents, generator=generator, return_dict=False)[0]
223
 
224
  intermediate_latents_cpu = current_latents.detach().cpu() # Latents after first half, moved to CPU
225
+
226
  # test with 2 segments
227
  if segment==2:
228
  final_latents = current_latents
 
243
  original_negative_prompt_embeds_cpu = negative_prompt_embeds.cpu()
244
  original_pooled_prompt_embeds_cpu = pooled_prompt_embeds.cpu()
245
  original_negative_pooled_prompt_embeds_cpu = negative_pooled_prompt_embeds.cpu()
246
+ original_add_time_ids_cpu = add_time_ids.cpu()
247
  state = {
248
  "intermediate_latents": intermediate_latents_cpu,
249
  "all_timesteps": all_timesteps_cpu, # Save full list generated by scheduler