ford442 committed
Commit d3b3856 · verified · 1 Parent(s): adb6539

Update app.py

Files changed (1)
  1. app.py +13 -9
app.py CHANGED
@@ -158,8 +158,11 @@ def generate(
     all_timesteps_cpu = timesteps.cpu()
     timesteps_split_np = np.array_split(all_timesteps_cpu.numpy(), 8)
     timesteps_split_for_state = [chunk for chunk in timesteps_split_np] # Store list of numpy arrays
-    split_point = num_inference_steps // 8
-    segment_timesteps = torch.from_numpy(timesteps_split_np[0]).to("cuda")
+
+    # test with one segment
+    #segment_timesteps = torch.from_numpy(timesteps_split_np[0]).to("cuda")
+    segment_timesteps = timesteps
+
     num_channels_latents = pipe.unet.config.in_channels
     latents = pipe.prepare_latents(
         batch_size=1, num_channels_latents=pipe.unet.config.in_channels, height=height, width=width,
@@ -209,8 +212,8 @@ def generate(
     added_cond_kwargs = {"text_embeds": unet_added_text_embeds, "time_ids": loop_add_time_ids}
     current_latents = latents # Start with loaded intermediate latents

-    for i, t in enumerate(pipe.progress_bar(segment_timesteps)): # Only first half timesteps
-        latent_model_input = torch.cat([current_latents] * 2) if guidance_scale > 1.0 else current_latents
+    for i, t in enumerate(pipe.progress_bar(segment_timesteps)):
+        latent_model_input = torch.cat([current_latents] * 2)
         latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)
         with torch.no_grad():
             noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=unet_prompt_embeds,added_cond_kwargs=added_cond_kwargs, return_dict=False)[0]
@@ -219,15 +222,16 @@ def generate(
         current_latents = pipe.scheduler.step(noise_pred, t, current_latents, generator=generator, return_dict=False)[0]

     intermediate_latents_cpu = current_latents.detach().cpu() # Latents after first half, moved to CPU
-
-    if segment==8:
-        final_latents = current_latents # Latents after the second half loop
+
+    # test with one segment
+    if segment==1:
+        final_latents = current_latents
     final_latents = final_latents / pipe.vae.config.scaling_factor
     with torch.no_grad():
-        image = pipe.vae.decode(final_latents, return_dict=False)[0] # VAE might prefer fp16/fp32
+        image = pipe.vae.decode(final_latents, return_dict=False)[0]
     image = pipe.image_processor.postprocess(image, output_type="pil")[0]
     output_image_file = f"rv_L_{seed}.png"
-    image.save(output_image_file) # Use output_image_file defined earlier
+    image.save(output_image_file)
     #timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
     #upload_to_ftp(filename)
     #uploadNote(prompt,num_inference_steps,guidance_scale,timestamp)
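Note on the first hunk: the previous code split the scheduler's timesteps into 8 segments and denoised only the first one, while the new code hands the full schedule to the loop. A minimal, self-contained sketch of that difference, assuming an illustrative step count and a linspace stand-in for pipe.scheduler.timesteps (neither value is taken from app.py):

import numpy as np
import torch

num_inference_steps = 40  # illustrative value, not from app.py
# Stand-in for pipe.scheduler.timesteps after set_timesteps(num_inference_steps)
timesteps = torch.linspace(999, 0, num_inference_steps).long()

all_timesteps_cpu = timesteps.cpu()
timesteps_split_np = np.array_split(all_timesteps_cpu.numpy(), 8)  # 8 roughly equal chunks

# Old behaviour: denoise only the first segment (~1/8 of the schedule)
first_segment = torch.from_numpy(timesteps_split_np[0])

# New behaviour ("test with one segment"): run the whole schedule in one loop
segment_timesteps = timesteps

print(len(first_segment), len(segment_timesteps))  # 5 40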
 
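Note on the second hunk: with the guidance_scale > 1.0 guard removed, the loop now always doubles the latent batch, which only makes sense when unet_prompt_embeds is stacked as [uncond, cond]. Below is a hedged sketch of one guided step with the guard kept, written as a helper; the chunk/combine lines and the name cfg_denoise_step are assumptions (that code sits in unchanged lines outside the hunks shown), not app.py's exact implementation.

import torch

def cfg_denoise_step(pipe, current_latents, t, unet_prompt_embeds,
                     added_cond_kwargs, guidance_scale, generator=None):
    # Duplicate latents only when classifier-free guidance is in use
    # (the guard the old code had and the new code drops).
    do_cfg = guidance_scale > 1.0
    latent_model_input = torch.cat([current_latents] * 2) if do_cfg else current_latents
    latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)
    with torch.no_grad():
        noise_pred = pipe.unet(
            latent_model_input, t,
            encoder_hidden_states=unet_prompt_embeds,
            added_cond_kwargs=added_cond_kwargs,
            return_dict=False,
        )[0]
    if do_cfg:
        # Standard diffusers-style guidance combination (assumed, not shown in this diff).
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    return pipe.scheduler.step(noise_pred, t, current_latents,
                               generator=generator, return_dict=False)[0]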
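Note on the third hunk: the decode-and-save tail is unchanged apart from dropped trailing comments. For reference, the same steps gathered into one helper; the name decode_and_save is illustrative only.

import torch

def decode_and_save(pipe, final_latents, output_image_file):
    # Undo the VAE scaling applied when the latents were produced.
    latents = final_latents / pipe.vae.config.scaling_factor
    with torch.no_grad():
        image = pipe.vae.decode(latents, return_dict=False)[0]
    # Convert the decoded tensor to a PIL image and write it to disk.
    image = pipe.image_processor.postprocess(image, output_type="pil")[0]
    image.save(output_image_file)
    return image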