Update handler_LAST_WORKING.py

handler_LAST_WORKING.py (CHANGED): +506 -344
@@ -1,24 +1,23 @@
 from dataclasses import dataclass
 from pathlib import Path
-import pathlib
-from typing import Dict, Any, Optional, Tuple
-import asyncio
-import base64
-import io
-import pprint
 import logging
 import random
-import
 import os
 import numpy as np
 import torch
-import
-
-from
-
-
-
-from

 from varnish import Varnish
 from varnish.utils import is_truthy, process_input_image
@@ -27,14 +26,13 @@ from varnish.utils import is_truthy, process_input_image
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

-
 # Get token from environment
 hf_token = os.getenv("HF_API_TOKEN")

 # Constraints
 MAX_LARGE_SIDE = 1280
-MAX_SMALL_SIDE = 768
-MAX_FRAMES =

 # Check environment variable for pipeline support
 support_image_prompt = is_truthy(os.getenv("SUPPORT_INPUT_IMAGE_PROMPT"))
@@ -48,10 +46,8 @@ class GenerationConfig:
     negative_prompt: str = "saturated, highlight, overexposed, highlighted, overlit, shaking, too bright, worst quality, inconsistent motion, blurry, jittery, distorted, cropped, watermarked, watermark, logo, subtitle, subtitles, lowres"

     # video model settings (will be used during generation of the initial raw video clip)
-
-
-    height: int = 416
-

     # this is a hack to fool LTX-Video into believing our input image is an actual video frame with poor encoding quality
     # after a quick benchmark using the value 70 seems like a sweet spot
@@ -62,8 +58,8 @@ class GenerationConfig:
     # visual glitches appear after about 169 frames, so we don't need more actually
     num_frames: int = (8 * 14) + 1

-    #
-    guidance_scale: float =

     num_inference_steps: int = 8

@@ -71,16 +67,16 @@ class GenerationConfig:
     seed: int = -1 # -1 means random seed

     # varnish settings (will be used for post-processing after the raw video clip has been generated
-    fps: int = 30
-    double_num_frames: bool = False
-    super_resolution: bool = False

-    grain_amount: float = 0.0

     # audio settings
     enable_audio: bool = False # Whether to generate audio
     audio_prompt: str = "" # Text prompt for audio generation
-    audio_negative_prompt: str = "voices, voice, talking, speaking, speech"

     # The range of the CRF scale is 0–51, where:
     # 0 is lossless (for 8 bit only, for 10 bit use -qp 0)
@@ -92,18 +88,26 @@ class GenerationConfig:
     # The range is exponential, so increasing the CRF value +6 results in roughly half the bitrate / file size, while -6 leads to roughly twice the bitrate.
     quality: int = 18

-    #
-
-    #
-
-    #
-

     def validate_and_adjust(self) -> 'GenerationConfig':
         """Validate and adjust parameters to meet constraints"""
@@ -111,7 +115,7 @@ class GenerationConfig:
         if not ((self.width == MAX_LARGE_SIDE and self.height == MAX_SMALL_SIDE) or
                 (self.width == MAX_SMALL_SIDE and self.height == MAX_LARGE_SIDE)):
             # For other resolutions, ensure total pixels don't exceed max
-            MAX_TOTAL_PIXELS = MAX_SMALL_SIDE * MAX_LARGE_SIDE

             # If total pixels exceed maximum, scale down proportionally
             total_pixels = self.width * self.height
@@ -131,371 +135,527 @@ class GenerationConfig:
        # Set random seed if not specified
        if self.seed == -1:
            self.seed = random.randint(0, 2**32 - 1)
-
-        return self

-        """Initialize the handler with LTX models and Varnish
-
-        Args:
-            model_path: Path to LTX model weights
-        """
-        print("EndpointHandler.__init__(): initializing..")
-        # Enable TF32 for potential speedup on Ampere GPUs
-        #torch.backends.cuda.matmul.allow_tf32 = True
-
-        # use distilled weights
-        model_path = Path("/repository/ltxv-2b-0.9.6-distilled-04-25.safetensors")

        )

-        vae = AutoencoderKLLTXVideo.from_single_file(model_path, torch_dtype=torch.bfloat16)
-
-        if support_image_prompt:
-            print("EndpointHandler.__init__(): initializing LTXImageToVideoPipeline..")
-            self.image_to_video = LTXImageToVideoPipeline.from_pretrained(
-                "/repository",
-                transformer=transformer,
-                vae=vae,
-                torch_dtype=torch.bfloat16
-            ).to("cuda")
-
-            #apply_teacache(self.image_to_video)

-            #self.image_to_video = torch.compile(self.image_to_video, mode="reduce-overhead", fullgraph=True)

        # Initialize Varnish for post-processing
        self.varnish = Varnish(
            device="cuda",
            model_base_dir="/repository/varnish",
-
-            # there is currently a bug with MMAudio and/or torch and/or the weight format and/or version..
-            # not sure how to fix that.. :/
-            #
-            # it says:
-            # File "dist-packages/varnish.py", line 152, in __init__
-            #   self._setup_mmaudio()
-            # File "dist-packages/varnish/varnish.py", line 165, in _setup_mmaudio
-            #   net.load_weights(torch.load(model.model_path, map_location=self.device, weights_only=False))
-            #   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-            # File "dist-packages/torch/serialization.py", line 1384, in load
-            #   return _legacy_load(
-            #   ^^^^^^^^^^^^^
-            # File "dist-packages/torch/serialization.py", line 1628, in _legacy_load
-            #   magic_number = pickle_module.load(f, **pickle_load_args)
-            #   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-            # _pickle.UnpicklingError: invalid load key, '<'.
-            enable_mmaudio=False,
        )
-
-        # Determine if TeaCache is already installed or not
-        self.text_to_video_teacache = False
-        self.image_to_video_teacache = False
-
-    async def process_frames(
-        self,
-        frames: torch.Tensor,
-        config: GenerationConfig
-    ) -> tuple[str, dict]:
-        """Post-process generated frames using Varnish
-        ""
        try:
-            #

-            #
-                "height": result.metadata.height,
-                "num_frames": result.metadata.frame_count,
-                "fps": result.metadata.fps,
-                "duration": result.metadata.duration,
-                "seed": config.seed,
-            }

        except Exception as e:
-
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        """Process

        Args:
-            data: Request data containing
-
-            - parameters (dict):
-              - prompt (required, string): list of concepts to keep in the video.
-              - negative_prompt (optional, string): list of concepts to ignore in the video.
-              - width (optional, int, default to 768): width, or horizontal size in pixels.
-              - height (optional, int, default to 512): height, or vertical size in pixels.
-              - input_image_quality (optional, int, default to 100): this is a trick we use to convert a "pristine" image into a "dirty" video frame. This helps fooling LTX-Video into turning the image into an animated one.
-              - num_frames (optional, int, default to 129): the numer of frames must be a multiple of 8, plus 1 frame.
-              - guidance_scale (optional, float, default to 3.5): Guidance scale (values between 3.0 and 4.0 are nice)
-              - num_inference_steps (optional, int, default to 50): number of inference steps
-              - seed (optional, int, default to -1): set a random number generator seed, -1 means random seed.
-              - fps (optional, int, default to 24): FPS of the final video (eg. 24, 25, 30, 60)
-              - double_num_frames (optional, bool): if enabled, the number of frames will be multiplied by 2 using RIFE
-              - super_resolution (optional, bool): if enabled, the resolution will be multiplied by 2 using Real_ESRGAN
-              - grain_amount (optional, float): amount of film grain to add to the output video
-              - enable_audio (optional, bool): automatically generate an audio track
-              - audio_prompt (optional, str): prompt to use for the audio generation (concepts to add)
-              - audio_negative_prompt (optional, str): nehative prompt to use for the audio generation (concepts to ignore)
-              - quality (optional, str, default to 18): The range of the CRF scale is 0–51, where 0 is lossless (for 8 bit only, for 10 bit use -qp 0), 23 is the default, and 51 is worst quality possible.
-              - enable_teacache (optional, bool, default to False): Generate faster at the cost of a slight quality loss
-              - teacache_threshold (optional, float, default to 0.05): Amount of cache, 0 (original), 0.03 (1.6x speedup), 0.05 (Default, 2.1x speedup).
-              - enable_enhance_a_video (optional, bool, default to False): enable the enhance_a_video optimization
-              - enhance_a_video_weight(optional, float, default to 5.0): amount of video enhancement to apply
-              - lora_model_name(optional, str, default to ""): HuggingFace repo ID or path to LoRA model
-              - lora_model_weight_file(optional, str, default to ""): Specific weight file to load from the LoRA model
-              - lora_model_trigger(optional, str, default to ""): Optional trigger word to prepend to the prompt
        Returns:
-            Dictionary
-            - video: Base64 encoded MP4 data URI
-            - content-type: MIME type
-            - metadata: Generation metadata
        """

-        params = data.get("parameters",
-
-        if not
            raise ValueError("Either prompt or image must be provided")
-
-        #logger.debug(f"Raw parameters:")
-        # pprint.pprint(params)
-
        # Create and validate configuration
        config = GenerationConfig(
            # general content settings
            prompt=input_prompt,
            negative_prompt=params.get("negative_prompt", GenerationConfig.negative_prompt),
-
-            # video model settings
            width=params.get("width", GenerationConfig.width),
            height=params.get("height", GenerationConfig.height),
            input_image_quality=params.get("input_image_quality", GenerationConfig.input_image_quality),
            num_frames=params.get("num_frames", GenerationConfig.num_frames),
            guidance_scale=params.get("guidance_scale", GenerationConfig.guidance_scale),
            num_inference_steps=params.get("num_inference_steps", GenerationConfig.num_inference_steps),

            # reproducible generation settings
            seed=params.get("seed", GenerationConfig.seed),

-            # varnish settings
-            fps=params.get("fps", GenerationConfig.fps),
-            double_num_frames=params.get("double_num_frames", GenerationConfig.double_num_frames),
-            super_resolution=params.get("super_resolution", GenerationConfig.super_resolution),
            grain_amount=params.get("grain_amount", GenerationConfig.grain_amount),
            enable_audio=params.get("enable_audio", GenerationConfig.enable_audio),
            audio_prompt=params.get("audio_prompt", GenerationConfig.audio_prompt),
            audio_negative_prompt=params.get("audio_negative_prompt", GenerationConfig.audio_negative_prompt),
            quality=params.get("quality", GenerationConfig.quality),

-            #
-            teacache_threshold=params.get("teacache_threshold", 0.05),

-            #
-            lora_model_weight_file=params.get("lora_model_weight_file", ""),
-            lora_model_trigger=params.get("lora_model_trigger", ""),
        ).validate_and_adjust()

-        #logger.debug(f"Global request settings:")
-        #pprint.pprint(config)
        try:
-            with torch.amp.
-                # Set random seeds
                random.seed(config.seed)
                np.random.seed(config.seed)
                torch.manual_seed(config.seed)
-                generator = torch.Generator(device='cuda')
-                generator = generator.manual_seed(config.seed)
-
-                # Configure enhance-a-video
-                #if config.enable_enhance_a_video:
-                #    enable_enhance()
-                #    set_enhance_weight(config.enhance_a_video_weight)

-                    # Timestep for decoding VAE noise: the timestep at which generated video is decoded
-                    "decode_timestep": 0.05,
-
-                    # Noise level for decoding VAE noise: the interpolation factor between random noise and denoised latents at the decode timestep
-                    "decode_noise_scale": 0.025,
-                }
-                #logger.info(f"Video model generation settings:")
-                #pprint.pprint(generation_kwargs)
-
-                # Handle LoRA loading/unloading
-                if hasattr(self, '_current_lora_model'):
-                    if self._current_lora_model != (config.lora_model_name, config.lora_model_weight_file):
-                        # Unload previous LoRA if it exists and is different
-                        if hasattr(self.text_to_video, 'unload_lora_weights'):
-                            print("Unloading LoRA weights for the text_to_video pipeline..")
-                            self.text_to_video.unload_lora_weights()
-
-                        if support_image_prompt and hasattr(self.image_to_video, 'unload_lora_weights'):
-                            print("Unloading LoRA weights for the image_to_video pipeline..")
-                            self.image_to_video.unload_lora_weights()
-
-                if config.lora_model_name:
-                    # Load new LoRA
-                    if hasattr(self.text_to_video, 'load_lora_weights'):
-                        print("Loading LoRA weights for the text_to_video pipeline..")
-                        self.text_to_video.load_lora_weights(
-                            config.lora_model_name,
-                            weight_name=config.lora_model_weight_file if config.lora_model_weight_file else None,
-                            token=hf_token,
-                        )
-                    if support_image_prompt and hasattr(self.image_to_video, 'load_lora_weights'):
-                        print("Loading LoRA weights for the image_to_video pipeline..")
-                        self.image_to_video.load_lora_weights(
-                            config.lora_model_name,
-                            weight_name=config.lora_model_weight_file if config.lora_model_weight_file else None,
-                            token=hf_token,
                        )
-
-                # Modify prompt if trigger word is provided
-                if config.lora_model_trigger:
-                    generation_kwargs["prompt"] = f"{config.lora_model_trigger} {generation_kwargs['prompt']}"
-
-                #enhance_a_video_config = EnhanceAVideoConfig(
-                #    weight=config.enhance_a_video_weight if config.enable_enhance_a_video else 0.0,
-                #    # doing some testing
-                #    num_frames_callback=lambda: (8 + 1),
-                #    # num_frames_callback=lambda: config.num_frames,
-                #    # num_frames_callback=lambda: (config.num_frames - 1),
-                #
-                #    _attention_type=1
-                #)

-                #
-                if
                try:
                    loop = asyncio.get_event_loop()
                except RuntimeError:
                    loop = asyncio.new_event_loop()
                    asyncio.set_event_loop(loop)

                torch.cuda.empty_cache()
-                torch.cuda.reset_peak_memory_stats()
                gc.collect()

                return {
@@ -503,8 +663,10 @@ class EndpointHandler:
                    "content-type": "video/mp4",
                    "metadata": metadata
                }
-
        except Exception as e:
-


handler_LAST_WORKING.py (updated file; added lines are marked with +):
 from dataclasses import dataclass
 from pathlib import Path
 import logging
+import base64
 import random
+import gc
 import os
 import numpy as np
 import torch
+from typing import Dict, Any, Optional, List, Union, Tuple
+import json
+from safetensors import safe_open
+
+from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
+from ltx_video.models.transformers.transformer3d import Transformer3DModel
+from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
+from ltx_video.schedulers.rf import RectifiedFlowScheduler, TimestepShifter
+from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem, LTXVideoPipeline
+from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
+from transformers import T5EncoderModel, T5Tokenizer, AutoModelForCausalLM, AutoProcessor, AutoTokenizer

 from varnish import Varnish
 from varnish.utils import is_truthy, process_input_image

 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

 # Get token from environment
 hf_token = os.getenv("HF_API_TOKEN")

 # Constraints
 MAX_LARGE_SIDE = 1280
+MAX_SMALL_SIDE = 768 # should be 720 but it must be divisible by 32
+MAX_FRAMES = (8 * 21) + 1 # visual glitches appear after about 169 frames, so we cap it

 # Check environment variable for pipeline support
 support_image_prompt = is_truthy(os.getenv("SUPPORT_INPUT_IMAGE_PROMPT"))
     negative_prompt: str = "saturated, highlight, overexposed, highlighted, overlit, shaking, too bright, worst quality, inconsistent motion, blurry, jittery, distorted, cropped, watermarked, watermark, logo, subtitle, subtitles, lowres"

     # video model settings (will be used during generation of the initial raw video clip)
+    width: int = 1216 # 768
+    height: int = 704 # 416

     # this is a hack to fool LTX-Video into believing our input image is an actual video frame with poor encoding quality
     # after a quick benchmark using the value 70 seems like a sweet spot

     # visual glitches appear after about 169 frames, so we don't need more actually
     num_frames: int = (8 * 14) + 1

+    # values between 3.0 and 4.0 are nice
+    guidance_scale: float = 3.0

     num_inference_steps: int = 8

     seed: int = -1 # -1 means random seed

     # varnish settings (will be used for post-processing after the raw video clip has been generated
+    fps: int = 30 # FPS of the final video (only applied at the very end, when converting to mp4)
+    double_num_frames: bool = False # if True, the number of frames will be multiplied by 2 using RIFE
+    super_resolution: bool = False # if True, the resolution will be multiplied by 2 using Real_ESRGAN

+    grain_amount: float = 0.0 # be careful, adding film grain can negatively impact video compression

     # audio settings
     enable_audio: bool = False # Whether to generate audio
     audio_prompt: str = "" # Text prompt for audio generation
+    audio_negative_prompt: str = "voices, voice, talking, speaking, speech" # Negative prompt for audio generation

     # The range of the CRF scale is 0–51, where:
     # 0 is lossless (for 8 bit only, for 10 bit use -qp 0)

     # The range is exponential, so increasing the CRF value +6 results in roughly half the bitrate / file size, while -6 leads to roughly twice the bitrate.
     quality: int = 18

+    # STG (Spatiotemporal Guidance) settings
+    stg_scale: float = 0.0
+    stg_rescale: float = 1.0
+    stg_mode: str = "attention_values" # Can be "attention_values", "attention_skip", "residual", or "transformer_block"

+    # VAE noise augmentation
+    decode_timestep: float = 0.05
+    decode_noise_scale: float = 0.025

+    # Other advanced settings
+    image_cond_noise_scale: float = 0.15
+    mixed_precision: bool = True # Use mixed precision for inference
+    stochastic_sampling: bool = True # Use stochastic sampling
+
+    # Sampling settings
+    sampler: Optional[str] = "from_checkpoint" # "uniform" or "linear-quadratic" or None (use default from checkpoint)
+
+    # Prompt enhancement
+    enhance_prompt: bool = False # Whether to enhance the prompt using an LLM
+    prompt_enhancement_words_threshold: int = 50 # Enhance prompt only if it has fewer words than this

     def validate_and_adjust(self) -> 'GenerationConfig':
         """Validate and adjust parameters to meet constraints"""
         if not ((self.width == MAX_LARGE_SIDE and self.height == MAX_SMALL_SIDE) or
                 (self.width == MAX_SMALL_SIDE and self.height == MAX_LARGE_SIDE)):
             # For other resolutions, ensure total pixels don't exceed max
+            MAX_TOTAL_PIXELS = MAX_SMALL_SIDE * MAX_LARGE_SIDE # or 921600 = 1280 * 720

             # If total pixels exceed maximum, scale down proportionally
             total_pixels = self.width * self.height

         # Set random seed if not specified
         if self.seed == -1:
             self.seed = random.randint(0, 2**32 - 1)

+        # Set up STG parameters
+        if self.stg_mode.lower() == "stg_av" or self.stg_mode.lower() == "attention_values":
+            self.stg_mode = "attention_values"
+        elif self.stg_mode.lower() == "stg_as" or self.stg_mode.lower() == "attention_skip":
+            self.stg_mode = "attention_skip"
+        elif self.stg_mode.lower() == "stg_r" or self.stg_mode.lower() == "residual":
+            self.stg_mode = "residual"
+        elif self.stg_mode.lower() == "stg_t" or self.stg_mode.lower() == "transformer_block":
+            self.stg_mode = "transformer_block"
+
+        # Check if we should enhance the prompt
+        if self.enhance_prompt and self.prompt:
+            prompt_word_count = len(self.prompt.split())
+            if prompt_word_count >= self.prompt_enhancement_words_threshold:
+                logger.info(f"Prompt has {prompt_word_count} words, which exceeds the threshold of {self.prompt_enhancement_words_threshold}. Prompt enhancement disabled.")
+                self.enhance_prompt = False

+        return self

+def load_image_to_tensor_with_resize_and_crop(
+    image_input: Union[str, bytes],
+    target_height: int = 704,
+    target_width: int = 1216,
+    quality: int = 100
+) -> torch.Tensor:
+    """Load and process an image into a tensor.
+
+    Args:
+        image_input: Either a file path (str) or image data (bytes)
+        target_height: Desired height of output tensor
+        target_width: Desired width of output tensor
+        quality: JPEG quality to use when re-encoding (to simulate lower quality images)
+    """
+    from PIL import Image
+    import io
+    import numpy as np
+
+    # Handle base64 data URI
+    if isinstance(image_input, str) and image_input.startswith('data:'):
+        header, encoded = image_input.split(",", 1)
+        image_data = base64.b64decode(encoded)
+        image = Image.open(io.BytesIO(image_data)).convert("RGB")
+    # Handle raw bytes
+    elif isinstance(image_input, bytes):
+        image = Image.open(io.BytesIO(image_input)).convert("RGB")
+    # Handle file path
+    elif isinstance(image_input, str):
+        image = Image.open(image_input).convert("RGB")
+    else:
+        raise ValueError("image_input must be either a file path, bytes, or base64 data URI")
+
+    # Apply JPEG compression if quality < 100 (to simulate a video frame)
+    if quality < 100:
+        buffer = io.BytesIO()
+        image.save(buffer, format="JPEG", quality=quality)
+        buffer.seek(0)
+        image = Image.open(buffer).convert("RGB")
+
+    input_width, input_height = image.size
+    aspect_ratio_target = target_width / target_height
+    aspect_ratio_frame = input_width / input_height
+    if aspect_ratio_frame > aspect_ratio_target:
+        new_width = int(input_height * aspect_ratio_target)
+        new_height = input_height
+        x_start = (input_width - new_width) // 2
+        y_start = 0
+    else:
+        new_width = input_width
+        new_height = int(input_width / aspect_ratio_target)
+        x_start = 0
+        y_start = (input_height - new_height) // 2
+
+    image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
+    image = image.resize((target_width, target_height))
+    frame_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).float()
+    frame_tensor = (frame_tensor / 127.5) - 1.0
+    # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
+    return frame_tensor.unsqueeze(0).unsqueeze(2)
+
+def calculate_padding(
+    source_height: int, source_width: int, target_height: int, target_width: int
+) -> tuple[int, int, int, int]:
+    """Calculate padding to reach target dimensions"""
+    # Calculate total padding needed
+    pad_height = target_height - source_height
+    pad_width = target_width - source_width
+
+    # Calculate padding for each side
+    pad_top = pad_height // 2
+    pad_bottom = pad_height - pad_top  # Handles odd padding
+    pad_left = pad_width // 2
+    pad_right = pad_width - pad_left  # Handles odd padding
+
+    # Return padded tensor
+    # Padding format is (left, right, top, bottom)
+    padding = (pad_left, pad_right, pad_top, pad_bottom)
+    return padding
+
+def prepare_conditioning(
+    conditioning_media_paths: List[str],
+    conditioning_strengths: List[float],
+    conditioning_start_frames: List[int],
+    height: int,
+    width: int,
+    num_frames: int,
+    input_image_quality: int = 100,
+    pipeline: Optional[LTXVideoPipeline] = None,
+) -> Optional[List[ConditioningItem]]:
+    """Prepare conditioning items based on input media paths and their parameters"""
+    conditioning_items = []
+    for path, strength, start_frame in zip(
+        conditioning_media_paths, conditioning_strengths, conditioning_start_frames
+    ):
+        # Load and process the conditioning image
+        frame_tensor = load_image_to_tensor_with_resize_and_crop(
+            path, height, width, quality=input_image_quality
+        )
+
+        # Trim frame count if needed
+        if pipeline:
+            frame_count = 1  # For image inputs, it's always 1
+            frame_count = pipeline.trim_conditioning_sequence(
+                start_frame, frame_count, num_frames
+            )
+
+        conditioning_items.append(
+            ConditioningItem(frame_tensor, start_frame, strength)
        )

+    return conditioning_items

+def create_ltx_video_pipeline(
+    config: GenerationConfig,
+    device: str = "cuda"
+) -> LTXVideoPipeline:
+    """Create and configure the LTX video pipeline"""

+    ckpt_path = "/repository/ltxv-2b-0.9.6-distilled-04-25.safetensors"
+
+    # Get allowed inference steps from config if available
+    allowed_inference_steps = None
+
+    assert os.path.exists(
+        ckpt_path
+    ), f"Ckpt path provided (--ckpt_path) {ckpt_path} does not exist"
+
+    with safe_open(ckpt_path, framework="pt") as f:
+        metadata = f.metadata()
+        config_str = metadata.get("config")
+        configs = json.loads(config_str)
+        allowed_inference_steps = configs.get("allowed_inference_steps", None)
+
+    # Initialize model components
+    vae = CausalVideoAutoencoder.from_pretrained(ckpt_path)
+    transformer = Transformer3DModel.from_pretrained(ckpt_path)
+
+    # Use constructor if sampler is specified, otherwise use from_pretrained
+    if config.sampler:
+        scheduler = RectifiedFlowScheduler(
+            sampler=("Uniform" if config.sampler.lower() == "uniform" else "LinearQuadratic")
+        )
+    else:
+        scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)
+
+    text_encoder = T5EncoderModel.from_pretrained("/repository/text_encoder")
+    patchifier = SymmetricPatchifier(patch_size=1)
+    tokenizer = T5Tokenizer.from_pretrained("/repository/tokenizer")
+
+    # Move models to the correct device
+    vae = vae.to(device)
+    transformer = transformer.to(device)
+    text_encoder = text_encoder.to(device)
+
+    # Set up precision
+    vae = vae.to(torch.bfloat16)
+    transformer = transformer.to(torch.bfloat16)
+    text_encoder = text_encoder.to(torch.bfloat16)
+
+    # Initialize prompt enhancer components if needed
+    prompt_enhancer_components = {
+        "prompt_enhancer_image_caption_model": None,
+        "prompt_enhancer_image_caption_processor": None,
+        "prompt_enhancer_llm_model": None,
+        "prompt_enhancer_llm_tokenizer": None
+    }
+
+    if config.enhance_prompt:
+        try:
+            # Use default models or ones specified by config
+            prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained(
+                "MiaoshouAI/Florence-2-large-PromptGen-v2.0",
+                trust_remote_code=True
+            )
+            prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained(
+                "MiaoshouAI/Florence-2-large-PromptGen-v2.0",
+                trust_remote_code=True
+            )
+            prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained(
+                "unsloth/Llama-3.2-3B-Instruct",
+                torch_dtype="bfloat16",
+            )
+            prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained(
+                "unsloth/Llama-3.2-3B-Instruct",
+            )

+            prompt_enhancer_components = {
+                "prompt_enhancer_image_caption_model": prompt_enhancer_image_caption_model,
+                "prompt_enhancer_image_caption_processor": prompt_enhancer_image_caption_processor,
+                "prompt_enhancer_llm_model": prompt_enhancer_llm_model,
+                "prompt_enhancer_llm_tokenizer": prompt_enhancer_llm_tokenizer
+            }
+        except Exception as e:
+            logger.warning(f"Failed to load prompt enhancer models: {e}")
+            config.enhance_prompt = False
+
+    # Construct the pipeline
+    pipeline = LTXVideoPipeline(
+        transformer=transformer,
+        patchifier=patchifier,
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+        scheduler=scheduler,
+        vae=vae,
+        allowed_inference_steps=allowed_inference_steps,
+        **prompt_enhancer_components
+    )
+
+    return pipeline

+class EndpointHandler:
+    """Handler for the LTX Video endpoint"""
+
+    def __init__(self, model_path: str = "/repository/"):
+        """Initialize the endpoint handler
+
+        Args:
+            model_path: Path to model weights (not used, as weights are in current directory)
+        """
+        # Enable TF32 for potential speedup on Ampere GPUs
+        torch.backends.cuda.matmul.allow_tf32 = True
+
        # Initialize Varnish for post-processing
        self.varnish = Varnish(
            device="cuda",
            model_base_dir="/repository/varnish",
+            enable_mmaudio=False, # Disable audio generation for now, since it is broken
        )

+        # The actual LTX pipeline will be loaded during inference to save memory
+        self.pipeline = None
+
+        # Perform warm-up inference
+        logger.info("Performing warm-up inference...")
+        self._warmup()
+        logger.info("Warm-up completed!")
+
+    def _warmup(self):
+        """Perform a warm-up inference to prepare the model for future requests"""
        try:
+            # Create a simple test configuration
+            test_config = GenerationConfig(
+                prompt="an astronaut is riding a cow in the desert, during golden hour",
+                negative_prompt="worst quality, lowres",
+                width=768, # Using smaller resolution for faster warm-up
+                height=416,
+                num_frames=33, # Just enough frames for a valid video
+                guidance_scale=1.0,
+                num_inference_steps=4, # Fewer steps for faster warm-up
+                seed=42, # Fixed seed for consistent warm-up
+                fps=16, # Lower FPS for faster processing
+                enable_audio=False, # No audio for warm-up
+                mixed_precision=True,
+            ).validate_and_adjust()

+            # Create the pipeline if it doesn't exist
+            if self.pipeline is None:
+                self.pipeline = create_ltx_video_pipeline(test_config)

+            # Run a quick inference
+            with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16), torch.no_grad():
+                # Set seeds for reproducibility
+                random.seed(test_config.seed)
+                np.random.seed(test_config.seed)
+                torch.manual_seed(test_config.seed)
+                generator = torch.Generator(device='cuda').manual_seed(test_config.seed)
+
+                # Generate video
+                result = self.pipeline(
+                    height=test_config.height,
+                    width=test_config.width,
+                    num_frames=test_config.num_frames,
+                    frame_rate=test_config.fps,
+                    prompt=test_config.prompt,
+                    negative_prompt=test_config.negative_prompt,
+                    guidance_scale=test_config.guidance_scale,
+                    num_inference_steps=test_config.num_inference_steps,
+                    generator=generator,
+                    output_type="pt",
+                    mixed_precision=test_config.mixed_precision,
+                    is_video=True,
+                    vae_per_channel_normalize=True,
+                )
+
+                # Just get the frames without full processing (faster warm-up)
+                frames = result.images
+
+                # Clean up
+                del result
+                torch.cuda.empty_cache()
+                gc.collect()
+
+            logger.info(f"Warm-up successful! Generated {frames.shape[2]} frames at {frames.shape[3]}x{frames.shape[4]}")
+
        except Exception as e:
+            # Log the error but don't fail initialization
+            import traceback
+            error_message = f"Warm-up failed (but this is non-critical): {str(e)}\n{traceback.format_exc()}"
+            logger.warning(error_message)
+
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        """Process inference requests

        Args:
+            data: Request data containing inputs and parameters
+
        Returns:
+            Dictionary with generated video and metadata
        """
+        # Extract inputs and parameters
+        inputs = data.get("inputs", {})

+        # Support both formats:
+        # 1. {"inputs": {"prompt": "...", "image": "..."}}
+        # 2. {"inputs": "..."} (prompt only)
+        if isinstance(inputs, str):
+            input_prompt = inputs
+            input_image = None
+        else:
+            input_prompt = inputs.get("prompt", "")
+            input_image = inputs.get("image")

+        params = data.get("parameters", {})
+
+        if not input_prompt and not input_image:
            raise ValueError("Either prompt or image must be provided")
+
        # Create and validate configuration
        config = GenerationConfig(
            # general content settings
            prompt=input_prompt,
            negative_prompt=params.get("negative_prompt", GenerationConfig.negative_prompt),
+
+            # video model settings
            width=params.get("width", GenerationConfig.width),
            height=params.get("height", GenerationConfig.height),
            input_image_quality=params.get("input_image_quality", GenerationConfig.input_image_quality),
            num_frames=params.get("num_frames", GenerationConfig.num_frames),
            guidance_scale=params.get("guidance_scale", GenerationConfig.guidance_scale),
            num_inference_steps=params.get("num_inference_steps", GenerationConfig.num_inference_steps),
+
+            # STG settings
+            stg_scale=params.get("stg_scale", GenerationConfig.stg_scale),
+            stg_rescale=params.get("stg_rescale", GenerationConfig.stg_rescale),
+            stg_mode=params.get("stg_mode", GenerationConfig.stg_mode),
+
+            # VAE noise settings
+            decode_timestep=params.get("decode_timestep", GenerationConfig.decode_timestep),
+            decode_noise_scale=params.get("decode_noise_scale", GenerationConfig.decode_noise_scale),
+            image_cond_noise_scale=params.get("image_cond_noise_scale", GenerationConfig.image_cond_noise_scale),
+
            # reproducible generation settings
            seed=params.get("seed", GenerationConfig.seed),

+            # varnish settings
+            fps=params.get("fps", GenerationConfig.fps),
+            double_num_frames=params.get("double_num_frames", GenerationConfig.double_num_frames),
+            super_resolution=params.get("super_resolution", GenerationConfig.super_resolution),
            grain_amount=params.get("grain_amount", GenerationConfig.grain_amount),
            enable_audio=params.get("enable_audio", GenerationConfig.enable_audio),
            audio_prompt=params.get("audio_prompt", GenerationConfig.audio_prompt),
            audio_negative_prompt=params.get("audio_negative_prompt", GenerationConfig.audio_negative_prompt),
            quality=params.get("quality", GenerationConfig.quality),

+            # advanced settings
+            mixed_precision=params.get("mixed_precision", GenerationConfig.mixed_precision),
+            stochastic_sampling=params.get("stochastic_sampling", GenerationConfig.stochastic_sampling),
+            sampler=params.get("sampler", GenerationConfig.sampler),

+            # prompt enhancement
+            enhance_prompt=params.get("enhance_prompt", GenerationConfig.enhance_prompt),
+            prompt_enhancement_words_threshold=params.get(
+                "prompt_enhancement_words_threshold",
+                GenerationConfig.prompt_enhancement_words_threshold
+            ),
        ).validate_and_adjust()

        try:
+            with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16), torch.no_grad():
+                # Set random seeds for reproducibility
                random.seed(config.seed)
                np.random.seed(config.seed)
                torch.manual_seed(config.seed)
+                generator = torch.Generator(device='cuda').manual_seed(config.seed)

+                # Create pipeline if not already created
+                if self.pipeline is None:
+                    self.pipeline = create_ltx_video_pipeline(config)
+
+                # Prepare conditioning items if an image is provided
+                conditioning_items = None
+                if input_image:
+                    conditioning_items = [
+                        ConditioningItem(
+                            load_image_to_tensor_with_resize_and_crop(
+                                input_image,
+                                config.height,
+                                config.width,
+                                quality=config.input_image_quality
+                            ),
+                            0, # Start frame
+                            1.0 # Conditioning strength
                        )
+                    ]

+                # Set up spatiotemporal guidance strategy
+                if config.stg_mode == "attention_values":
+                    skip_layer_strategy = SkipLayerStrategy.AttentionValues
+                elif config.stg_mode == "attention_skip":
+                    skip_layer_strategy = SkipLayerStrategy.AttentionSkip
+                elif config.stg_mode == "residual":
+                    skip_layer_strategy = SkipLayerStrategy.Residual
+                elif config.stg_mode == "transformer_block":
+                    skip_layer_strategy = SkipLayerStrategy.TransformerBlock
+
+                # Generate video with LTX pipeline
+                result = self.pipeline(
+                    height=config.height,
+                    width=config.width,
+                    num_frames=config.num_frames,
+                    frame_rate=config.fps,
+                    prompt=config.prompt,
+                    negative_prompt=config.negative_prompt,
+                    guidance_scale=config.guidance_scale,
+                    num_inference_steps=config.num_inference_steps,
+                    generator=generator,
+                    output_type="pt", # Return as PyTorch tensor
+                    skip_layer_strategy=skip_layer_strategy,
+                    stg_scale=config.stg_scale,
+                    do_rescaling=config.stg_rescale != 1.0,
+                    rescaling_scale=config.stg_rescale,
+                    conditioning_items=conditioning_items,
+                    decode_timestep=config.decode_timestep,
+                    decode_noise_scale=config.decode_noise_scale,
+                    image_cond_noise_scale=config.image_cond_noise_scale,
+                    mixed_precision=config.mixed_precision,
+                    is_video=True,
+                    vae_per_channel_normalize=True,
+                    stochastic_sampling=config.stochastic_sampling,
+                    enhance_prompt=config.enhance_prompt,
+                )
+
+                # Get the generated frames
+                frames = result.images
+
+                # FIX: Convert LTX output format to varnish-compatible format
+                # LTX outputs: [batch, channels, frames, height, width]
+                # We need: [frames, channels, height, width] for varnish
+                frames = frames.squeeze(0) # Remove batch: [channels, frames, height, width]
+                frames = frames.permute(1, 0, 2, 3) # Reorder to: [frames, channels, height, width]
+
+                # Convert from [0, 1] to [0, 255] range
+                frames = frames * 255.0
+
+                # Convert to uint8
+                frames = frames.to(torch.uint8)
+
+                # Process the generated frames with Varnish
+                import asyncio
                try:
                    loop = asyncio.get_event_loop()
                except RuntimeError:
                    loop = asyncio.new_event_loop()
                    asyncio.set_event_loop(loop)
+
+                # Process with Varnish for post-processing
+                varnish_result = loop.run_until_complete(
+                    self.varnish(
+                        frames,
+                        fps=config.fps,
+                        double_num_frames=config.double_num_frames,
+                        super_resolution=config.super_resolution,
+                        grain_amount=config.grain_amount,
+                        enable_audio=config.enable_audio,
+                        audio_prompt=config.audio_prompt or config.prompt,
+                        audio_negative_prompt=config.audio_negative_prompt,
+                    )
+                )

+                # Get the final video as a data URI
+                video_uri = loop.run_until_complete(
+                    varnish_result.write(
+                        type="data-uri",
+                        quality=config.quality
+                    )
+                )
+
+                # Prepare metadata about the generated video
+                metadata = {
+                    "width": varnish_result.metadata.width,
+                    "height": varnish_result.metadata.height,
+                    "num_frames": varnish_result.metadata.frame_count,
+                    "fps": varnish_result.metadata.fps,
+                    "duration": varnish_result.metadata.duration,
+                    "seed": config.seed,
+                    "prompt": config.prompt,
+                }
+
+                # Clean up to prevent CUDA OOM errors
+                del result
                torch.cuda.empty_cache()
                gc.collect()

                return {
                    "content-type": "video/mp4",
                    "metadata": metadata
                }
+
        except Exception as e:
+            # Log the error and reraise
+            import traceback
+            error_message = f"Error generating video: {str(e)}\n{traceback.format_exc()}"
+            logger.error(error_message)
+            raise RuntimeError(error_message)