ford442 commited on
Commit
5338d3c
·
verified ·
1 Parent(s): 81107b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -12
app.py CHANGED
@@ -158,11 +158,9 @@ def generate(
158
  all_timesteps_cpu = timesteps.cpu()
159
  timesteps_split_np = np.array_split(all_timesteps_cpu.numpy(), 2)
160
  timesteps_split_for_state = [chunk for chunk in timesteps_split_np] # Store list of numpy arrays
161
-
162
  # test with 2 segments
163
  segment_timesteps = torch.from_numpy(timesteps_split_np[0]).to("cuda")
164
  #segment_timesteps = timesteps
165
-
166
  num_channels_latents = pipe.unet.config.in_channels
167
  latents = pipe.prepare_latents(
168
  batch_size=1, num_channels_latents=pipe.unet.config.in_channels, height=height, width=width,
@@ -189,21 +187,16 @@ def generate(
189
  state = torch.load(state_file, weights_only=False)
190
  latents = state["intermediate_latents"].to("cuda", dtype=torch.bfloat16)
191
  guidance_scale = state["guidance_scale"]
192
-
193
  all_timesteps_cpu = state["all_timesteps"]
194
- timesteps_split_np = np.array_split(all_timesteps_cpu.numpy(), 2)
195
- timesteps_split_for_state = [chunk for chunk in timesteps_split_np] # Store list of numpy arrays
 
196
  #timesteps_chunk_np = state["timesteps_split"][segment - 1]
197
- segment_timesteps = torch.from_numpy(state["timesteps_split"][segment - 1]).to("cuda")
198
-
199
  # test with 2 segments
200
  segment_timesteps = torch.from_numpy(timesteps_split_np[1]).to("cuda")
201
-
202
- seed = state["seed"]
203
- height = state["height"]
204
- width = state["width"]
205
  generator = torch.Generator(device='cuda').manual_seed(seed)
206
- pipe.scheduler.set_timesteps(len(all_timesteps_cpu), device=device)
207
  prompt_embeds = state["prompt_embeds"].to("cuda", dtype=torch.bfloat16)
208
  negative_prompt_embeds = state["negative_prompt_embeds"].to("cuda", dtype=torch.bfloat16)
209
  pooled_prompt_embeds = state["pooled_prompt_embeds"].to("cuda", dtype=torch.bfloat16)
@@ -215,6 +208,8 @@ def generate(
215
  add_time_ids = state["add_time_ids"].to("cuda", dtype=torch.bfloat16) # Original time IDs
216
  loop_add_time_ids = add_time_ids # Start with original loaded ones
217
  loop_add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
 
 
218
  added_cond_kwargs = {"text_embeds": unet_added_text_embeds, "time_ids": loop_add_time_ids}
219
  current_latents = latents # Start with loaded intermediate latents
220
 
 
158
  all_timesteps_cpu = timesteps.cpu()
159
  timesteps_split_np = np.array_split(all_timesteps_cpu.numpy(), 2)
160
  timesteps_split_for_state = [chunk for chunk in timesteps_split_np] # Store list of numpy arrays
 
161
  # test with 2 segments
162
  segment_timesteps = torch.from_numpy(timesteps_split_np[0]).to("cuda")
163
  #segment_timesteps = timesteps
 
164
  num_channels_latents = pipe.unet.config.in_channels
165
  latents = pipe.prepare_latents(
166
  batch_size=1, num_channels_latents=pipe.unet.config.in_channels, height=height, width=width,
 
187
  state = torch.load(state_file, weights_only=False)
188
  latents = state["intermediate_latents"].to("cuda", dtype=torch.bfloat16)
189
  guidance_scale = state["guidance_scale"]
190
+ seed = state["seed"]
191
  all_timesteps_cpu = state["all_timesteps"]
192
+ height = state["height"]
193
+ width = state["width"]
194
+ pipe.scheduler.set_timesteps(len(all_timesteps_cpu), device=device)
195
  #timesteps_chunk_np = state["timesteps_split"][segment - 1]
196
+ #segment_timesteps = torch.from_numpy(state["timesteps_split"][segment - 1]).to("cuda")
 
197
  # test with 2 segments
198
  segment_timesteps = torch.from_numpy(timesteps_split_np[1]).to("cuda")
 
 
 
 
199
  generator = torch.Generator(device='cuda').manual_seed(seed)
 
200
  prompt_embeds = state["prompt_embeds"].to("cuda", dtype=torch.bfloat16)
201
  negative_prompt_embeds = state["negative_prompt_embeds"].to("cuda", dtype=torch.bfloat16)
202
  pooled_prompt_embeds = state["pooled_prompt_embeds"].to("cuda", dtype=torch.bfloat16)
 
208
  add_time_ids = state["add_time_ids"].to("cuda", dtype=torch.bfloat16) # Original time IDs
209
  loop_add_time_ids = add_time_ids # Start with original loaded ones
210
  loop_add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
211
+ timesteps_split_np = np.array_split(all_timesteps_cpu.numpy(), 2)
212
+ timesteps_split_for_state = [chunk for chunk in timesteps_split_np] # Store list of numpy arrays
213
  added_cond_kwargs = {"text_embeds": unet_added_text_embeds, "time_ids": loop_add_time_ids}
214
  current_latents = latents # Start with loaded intermediate latents
215