Update app.py
app.py CHANGED
@@ -158,11 +158,9 @@ def generate(
     all_timesteps_cpu = timesteps.cpu()
     timesteps_split_np = np.array_split(all_timesteps_cpu.numpy(), 2)
     timesteps_split_for_state = [chunk for chunk in timesteps_split_np] # Store list of numpy arrays
-
     # test with 2 segments
     segment_timesteps = torch.from_numpy(timesteps_split_np[0]).to("cuda")
     #segment_timesteps = timesteps
-
     num_channels_latents = pipe.unet.config.in_channels
     latents = pipe.prepare_latents(
         batch_size=1, num_channels_latents=pipe.unet.config.in_channels, height=height, width=width,
@@ -189,21 +187,16 @@ def generate(
     state = torch.load(state_file, weights_only=False)
     latents = state["intermediate_latents"].to("cuda", dtype=torch.bfloat16)
     guidance_scale = state["guidance_scale"]
-
+    seed = state["seed"]
     all_timesteps_cpu = state["all_timesteps"]
-
-
+    height = state["height"]
+    width = state["width"]
+    pipe.scheduler.set_timesteps(len(all_timesteps_cpu), device=device)
     #timesteps_chunk_np = state["timesteps_split"][segment - 1]
-    segment_timesteps = torch.from_numpy(state["timesteps_split"][segment - 1]).to("cuda")
-
+    #segment_timesteps = torch.from_numpy(state["timesteps_split"][segment - 1]).to("cuda")
     # test with 2 segments
     segment_timesteps = torch.from_numpy(timesteps_split_np[1]).to("cuda")
-
-    seed = state["seed"]
-    height = state["height"]
-    width = state["width"]
     generator = torch.Generator(device='cuda').manual_seed(seed)
-    pipe.scheduler.set_timesteps(len(all_timesteps_cpu), device=device)
     prompt_embeds = state["prompt_embeds"].to("cuda", dtype=torch.bfloat16)
     negative_prompt_embeds = state["negative_prompt_embeds"].to("cuda", dtype=torch.bfloat16)
     pooled_prompt_embeds = state["pooled_prompt_embeds"].to("cuda", dtype=torch.bfloat16)
@@ -215,6 +208,8 @@ def generate(
     add_time_ids = state["add_time_ids"].to("cuda", dtype=torch.bfloat16) # Original time IDs
     loop_add_time_ids = add_time_ids # Start with original loaded ones
     loop_add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
+    timesteps_split_np = np.array_split(all_timesteps_cpu.numpy(), 2)
+    timesteps_split_for_state = [chunk for chunk in timesteps_split_np] # Store list of numpy arrays
     added_cond_kwargs = {"text_embeds": unet_added_text_embeds, "time_ids": loop_add_time_ids}
     current_latents = latents # Start with loaded intermediate latents
 
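In short, the change implements a two-segment, resumable denoising run: the timestep schedule is split into contiguous chunks with np.array_split, segment 0 runs immediately, and the resume path reloads a saved state dict (intermediate_latents, all_timesteps, guidance_scale, seed, height, width, plus the prompt embeddings) and continues over the second chunk. Below is a minimal sketch of that split-and-resume pattern, not the Space's actual code: the split check is runnable on CPU, the save/resume portion is shown as comments since it needs the full pipeline, and the torch.save call and the file name segment_state.pt are illustrative assumptions; the state keys and resume lines mirror the diff.

```python
import numpy as np
import torch

def split_schedule(timesteps, num_segments=2):
    """Split a scheduler's timestep tensor into contiguous numpy chunks."""
    all_timesteps_cpu = timesteps.cpu()
    # np.array_split (unlike np.split) tolerates lengths that are not
    # evenly divisible by num_segments, e.g. 31 steps -> chunks of 16 and 15.
    return all_timesteps_cpu, np.array_split(all_timesteps_cpu.numpy(), num_segments)

# Runnable check of the split itself, using a 30-step stand-in schedule
# (999 down to 13 in steps of 34) in place of pipe.scheduler.timesteps:
schedule = torch.arange(999, -1, -34)
full, chunks = split_schedule(schedule)
assert torch.equal(torch.cat([torch.from_numpy(c) for c in chunks]), full)

# Segment 0 (first run): denoise over chunks[0], then persist the state the
# resume path reads back, mirroring the keys used in the diff:
#   torch.save({"intermediate_latents": latents,
#               "all_timesteps": full,
#               "timesteps_split": chunks,
#               "guidance_scale": guidance_scale,
#               "seed": seed, "height": height, "width": width,
#               ...}, "segment_state.pt")  # illustrative file name
#
# Segment 1 (resume): reload, rebuild scheduler state, continue on chunks[1]:
#   state = torch.load("segment_state.pt", weights_only=False)
#   latents = state["intermediate_latents"].to("cuda", dtype=torch.bfloat16)
#   pipe.scheduler.set_timesteps(len(state["all_timesteps"]), device=device)
#   segment_timesteps = torch.from_numpy(state["timesteps_split"][1]).to("cuda")
```

Note that the resume hunk recomputes timesteps_split_np from all_timesteps_cpu (new lines 211-212) rather than reading the saved state["timesteps_split"] entry, whose use is commented out at new line 196.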