Lotus-2_Depth / pipeline.py
haodongli's picture
init
436f5aa
raw
history blame
9.02 kB
from typing import Union, Optional, List, Dict, Any
import numpy as np
import torch
from diffusers import FluxPipeline
from diffusers.pipelines.flux import FluxPipelineOutput
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
from diffusers.utils import is_torch_xla_available
from utils.image_utils import resize_image, resize_image_first
# Record once whether torch_xla can be used; `xm` is only bound when it can,
# so every use of `xm` must be guarded by `XLA_AVAILABLE`.
XLA_AVAILABLE = bool(is_torch_xla_available())
if XLA_AVAILABLE:
    import torch_xla.core.xla_model as xm
class Lotus2Pipeline(FluxPipeline):
    """FLUX-based two-stage image-to-image prediction pipeline.

    `__call__` encodes the input image with the VAE, runs a single transformer pass with the
    `"core_predictor"` adapter, feeds the result through `self.local_continuity_module`, refines it
    with a scheduler-driven loop under the `"detail_sharpener"` adapter, and finally decodes the
    latents back to image space.
    """

    @torch.no_grad()
    def __call__(
        self,
        rgb_in: Optional[torch.FloatTensor] = None,
        prompt: Optional[Union[str, List[str]]] = None,
        num_inference_steps: int = 10,
        output_type: Optional[str] = "pil",
        process_res: Optional[int] = None,
        timestep_core_predictor: int = 1,
        guidance_scale: float = 3.5,
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            rgb_in (`torch.FloatTensor`):
                The input image batch to be used for generation. Required. Indexed as a 4-D tensor
                (`shape[0]` is the batch size, `shape[2:]` the spatial size).
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the prediction. `None` is treated as `''`. A single
                prompt is shared by every image in the batch.
            num_inference_steps (`int`, *optional*, defaults to 10):
                The number of denoising steps of the detail-sharpener loop. More denoising steps usually lead
                to a sharper prediction at the expense of slower inference.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. With
                `"latent"` the packed latents are returned without decoding or resizing.
            process_res (`int`, *optional*):
                Forwarded to `resize_image_first` to resize `rgb_in` before encoding; the exact semantics
                (including `None`) are defined by that helper.
            timestep_core_predictor (`int`, *optional*, defaults to 1):
                Timestep passed (divided by 1000) to the transformer for the single core-predictor pass.
            guidance_scale (`float`, *optional*, defaults to 3.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                Only used when the transformer config has `guidance_embeds` enabled.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).

        Returns:
            [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
            is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
            images.

        Raises:
            ValueError: If `rgb_in` is not provided.
        """
        if rgb_in is None:
            raise ValueError("`rgb_in` must be provided: the pipeline predicts from an input image.")

        # 1. prepare
        batch_size = rgb_in.shape[0]
        # Remember the caller's spatial size so the decoded output can be resized back to it.
        input_size = rgb_in.shape[2:]
        rgb_in = resize_image_first(rgb_in, process_res)
        height, width = rgb_in.shape[2:]

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = {} if joint_attention_kwargs is None else joint_attention_kwargs
        self._interrupt = False

        device = self._execution_device

        # 2. encode prompt
        # A missing prompt means the empty prompt; a lone prompt is repeated for every image so
        # the text embeddings have the same batch size as the image latents.
        if prompt is None:
            prompt = ""
        prompts = [prompt] if isinstance(prompt, str) else list(prompt)
        repeats_per_prompt = batch_size if len(prompts) == 1 else 1
        (
            prompt_embeds,
            pooled_prompt_embeds,
            text_ids,
        ) = self.encode_prompt(
            prompt=prompts,
            prompt_2=None,
            device=device,
            num_images_per_prompt=repeats_per_prompt,
        )

        # 3. prepare latent variables
        rgb_in = rgb_in.to(device=device, dtype=self.dtype)
        rgb_latents = self.vae.encode(rgb_in).latent_dist.sample()
        rgb_latents = (rgb_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
        packed_rgb_latents = self._pack_latents(
            rgb_latents,
            batch_size=rgb_latents.shape[0],
            num_channels_latents=rgb_latents.shape[1],
            height=rgb_latents.shape[2],
            width=rgb_latents.shape[3],
        )
        # Both stages use the same image ids, so they are built once.
        latent_image_ids = self._prepare_latent_image_ids(
            batch_size,
            rgb_latents.shape[2] // 2,
            rgb_latents.shape[3] // 2,
            device,
            rgb_latents.dtype,
        )

        # 4. prepare timesteps
        timestep_core_predictor = (
            torch.tensor(timestep_core_predictor)
            .expand(batch_size)
            .to(device=rgb_in.device, dtype=rgb_in.dtype)
        )
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
        image_seq_len = packed_rgb_latents.shape[1]
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.base_image_seq_len,
            self.scheduler.config.max_image_seq_len,
            self.scheduler.config.base_shift,
            self.scheduler.config.max_shift,
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=sigmas,
            mu=mu,
        )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # 5. handle guidance
        if self.transformer.config.guidance_embeds:
            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
            guidance = guidance.expand(packed_rgb_latents.shape[0])
        else:
            guidance = None

        # 6. core predictor: one transformer pass on the packed image latents
        self.transformer.set_adapter("core_predictor")
        latents = self.transformer(
            hidden_states=packed_rgb_latents,
            timestep=timestep_core_predictor / 1000,
            guidance=guidance,
            pooled_projections=pooled_prompt_embeds,
            encoder_hidden_states=prompt_embeds,
            txt_ids=text_ids,
            img_ids=latent_image_ids,
            joint_attention_kwargs=self.joint_attention_kwargs,
            return_dict=False,
        )[0]
        # The continuity module is applied to unpacked latents.
        latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
        latents = self.local_continuity_module(latents)

        # 7. Denoising loop for detail sharpener
        self.transformer.set_adapter("detail_sharpener")
        latents = self._pack_latents(
            latents,
            batch_size=latents.shape[0],
            num_channels_latents=latents.shape[1],
            height=latents.shape[2],
            width=latents.shape[3],
        )
        last_step = len(timesteps) - 1
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0]).to(latents.dtype)

                noise_pred = self.transformer(
                    hidden_states=latents,
                    timestep=timestep / 1000,
                    guidance=guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype and torch.backends.mps.is_available():
                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                    latents = latents.to(latents_dtype)

                # advance the progress bar
                if i == last_step or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        latents = latents.to(dtype=self.dtype)

        if output_type == "latent":
            # Packed latents are returned untouched: no decoding and no image resizing.
            image = latents
        else:
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)
            # Resize output image to match input size
            image = resize_image(image, input_size)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return FluxPipelineOutput(images=image)