from dataclasses import dataclass
from pathlib import Path
import logging
import base64
import random
import gc
import os
import numpy as np
import torch
from typing import Dict, Any, Optional, List, Union, Tuple
import json
from safetensors import safe_open

from ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
from ltx_video.models.transformers.transformer3d import Transformer3DModel
from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
from ltx_video.schedulers.rf import RectifiedFlowScheduler, TimestepShifter
from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem, LTXVideoPipeline, LTXMultiScalePipeline
from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
from transformers import T5EncoderModel, T5Tokenizer, AutoModelForCausalLM, AutoProcessor, AutoTokenizer

from varnish import Varnish
from varnish.utils import is_truthy, process_input_image

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

hf_token = os.getenv("HF_API_TOKEN")

# Resolution and frame-count limits (dimensions must be multiples of 32,
# frame counts of the form 8k + 1; see GenerationConfig.validate_and_adjust)
MAX_LARGE_SIDE = 1280
MAX_SMALL_SIDE = 768
MAX_FRAMES = (8 * 21) + 1  # 169 frames

# Default denoising schedules for the two passes of the multi-scale pipeline
DEFAULT_FIRST_PASS_TIMESTEPS = [1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250]
DEFAULT_SECOND_PASS_TIMESTEPS = [0.9094, 0.7250, 0.4219]

# Timesteps supported by the distilled checkpoint
ALLOWED_TIMESTEPS = [1.0, 0.9937, 0.9875, 0.9812, 0.975, 0.9094, 0.725, 0.4219]

support_image_prompt = is_truthy(os.getenv("SUPPORT_INPUT_IMAGE_PROMPT"))

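# Note: both default schedules are subsets of ALLOWED_TIMESTEPS. The helper
# below derives shorter schedules from the same list, since the distilled
# checkpoint is only expected to behave well at these discrete timesteps.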
def generate_valid_timesteps(num_steps: int, allowed_timesteps: List[float], start_high: bool = True) -> List[float]:
    """Generate valid timesteps by selecting from the allowed timesteps list"""
    if num_steps >= len(allowed_timesteps):
        return allowed_timesteps

    if num_steps == 1:
        # A single step uses either the highest or the lowest allowed timestep
        return [allowed_timesteps[0] if start_high else allowed_timesteps[-1]]

    if start_high:
        # Pick evenly spaced entries, always including the first and the last
        indices = []
        for i in range(num_steps):
            idx = int(i * (len(allowed_timesteps) - 1) / (num_steps - 1))
            indices.append(idx)
        return [allowed_timesteps[i] for i in indices]
    else:
        # Take the num_steps lowest timesteps (the tail of the schedule)
        return allowed_timesteps[-num_steps:]

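# Illustrative results of the selection logic above:
#   generate_valid_timesteps(3, ALLOWED_TIMESTEPS)                   -> [1.0, 0.9812, 0.4219]
#   generate_valid_timesteps(3, ALLOWED_TIMESTEPS, start_high=False) -> [0.9094, 0.725, 0.4219]
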
@dataclass
class GenerationConfig:
    """Configuration for video generation"""

    # Content settings
    prompt: str = ""
    negative_prompt: str = "saturated, highlight, overexposed, highlighted, overlit, shaking, too bright, worst quality, inconsistent motion, blurry, jittery, distorted, cropped, watermarked, watermark, logo, subtitle, subtitles, lowres"

    # Output dimensions
    width: int = 1216
    height: int = 704

    # JPEG quality used to re-encode the input image (values below 100
    # deliberately degrade the conditioning image)
    input_image_quality: int = 70

    # Number of frames to generate; must be of the form 8k + 1
    num_frames: int = (8 * 14) + 1  # 113 frames

    # Diffusion settings
    guidance_scale: float = 3.0
    num_inference_steps: int = 8

    # Pipeline settings: "multi-scale" runs a low-resolution pass, upsamples
    # the latents, then refines them in a second pass
    pipeline_type: str = "multi-scale"
    downscale_factor: float = 0.6666666
    first_pass_timesteps: Optional[List[float]] = None
    second_pass_timesteps: Optional[List[float]] = None

    # Reproducibility (-1 means pick a random seed)
    seed: int = -1

    # Varnish post-processing settings
    fps: int = 30
    double_num_frames: bool = False  # double the frame count via interpolation
    super_resolution: bool = False   # upscale the decoded video

    # Film grain intensity (0.0 disables it)
    grain_amount: float = 0.0

    # Audio settings
    enable_audio: bool = False
    audio_prompt: str = ""
    audio_negative_prompt: str = "voices, voice, talking, speaking, speech"

    # Output video encoding quality (CRF-style: lower means higher quality)
    quality: int = 18

    # Spatiotemporal guidance (STG) settings
    stg_scale: float = 0.0
    stg_rescale: float = 1.0
    stg_mode: str = "attention_values"

    # VAE decoding settings
    decode_timestep: float = 0.05
    decode_noise_scale: float = 0.025

    # Miscellaneous inference settings
    image_cond_noise_scale: float = 0.15
    mixed_precision: bool = True
    stochastic_sampling: bool = False

    # Sampler: "uniform", "linear-quadratic", or "from_checkpoint" (use the
    # scheduler configuration stored in the checkpoint)
    sampler: Optional[str] = "from_checkpoint"

    # Prompt enhancement settings
    enhance_prompt: bool = False
    prompt_enhancement_words_threshold: int = 50

    def validate_and_adjust(self) -> 'GenerationConfig':
        """Validate and adjust parameters to meet constraints"""
        # Snap dimensions: either one of the two native orientations, or any
        # multiple-of-32 size within the total pixel budget
        if not ((self.width == MAX_LARGE_SIDE and self.height == MAX_SMALL_SIDE) or
                (self.width == MAX_SMALL_SIDE and self.height == MAX_LARGE_SIDE)):
            MAX_TOTAL_PIXELS = MAX_SMALL_SIDE * MAX_LARGE_SIDE  # 983,040 pixels

            total_pixels = self.width * self.height
            if total_pixels > MAX_TOTAL_PIXELS:
                # Scale down proportionally, then round to multiples of 32
                scale = (MAX_TOTAL_PIXELS / total_pixels) ** 0.5
                self.width = max(128, min(MAX_LARGE_SIDE, round(self.width * scale / 32) * 32))
                self.height = max(128, min(MAX_LARGE_SIDE, round(self.height * scale / 32) * 32))
            else:
                # Round to the nearest multiple of 32 without rescaling
                self.width = max(128, min(MAX_LARGE_SIDE, round(self.width / 32) * 32))
                self.height = max(128, min(MAX_LARGE_SIDE, round(self.height / 32) * 32))

        # Round num_frames down to the nearest 8k + 1, capped at MAX_FRAMES
        k = (self.num_frames - 1) // 8
        self.num_frames = min((k * 8) + 1, MAX_FRAMES)

        # Pick a random seed when none was requested
        if self.seed == -1:
            self.seed = random.randint(0, 2**32 - 1)

        # Normalize STG mode aliases and reject unknown values early (an
        # unrecognized mode would otherwise surface later as an obscure error)
        if self.stg_mode.lower() == "stg_av" or self.stg_mode.lower() == "attention_values":
            self.stg_mode = "attention_values"
        elif self.stg_mode.lower() == "stg_as" or self.stg_mode.lower() == "attention_skip":
            self.stg_mode = "attention_skip"
        elif self.stg_mode.lower() == "stg_r" or self.stg_mode.lower() == "residual":
            self.stg_mode = "residual"
        elif self.stg_mode.lower() == "stg_t" or self.stg_mode.lower() == "transformer_block":
            self.stg_mode = "transformer_block"
        else:
            raise ValueError(f"Invalid stg_mode: {self.stg_mode}")

        # Long prompts are assumed to be detailed enough already, so skip
        # enhancement for them
        if self.enhance_prompt and self.prompt:
            prompt_word_count = len(self.prompt.split())
            if prompt_word_count >= self.prompt_enhancement_words_threshold:
                logger.info(f"Prompt has {prompt_word_count} words, which exceeds the threshold of {self.prompt_enhancement_words_threshold}. Prompt enhancement disabled.")
                self.enhance_prompt = False

        return self

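# Worked example of the adjustments above:
#   GenerationConfig(width=1920, height=1080, num_frames=100).validate_and_adjust()
# scales 1920x1080 down to the 768 * 1280 pixel budget, snapping to multiples
# of 32 (-> 1280x736), and rounds num_frames down to the 8k + 1 form (-> 97).
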
def load_image_to_tensor_with_resize_and_crop(
    image_input: Union[str, bytes],
    target_height: int = 704,
    target_width: int = 1216,
    quality: int = 100
) -> torch.Tensor:
    """Load and process an image into a tensor.

    Args:
        image_input: Either a file path (str), a base64 data URI (str), or raw image data (bytes)
        target_height: Desired height of output tensor
        target_width: Desired width of output tensor
        quality: JPEG quality to use when re-encoding (to simulate lower quality images)
    """
    from PIL import Image
    import io

    # Decode the input into a PIL image
    if isinstance(image_input, str) and image_input.startswith('data:'):
        # Base64 data URI
        header, encoded = image_input.split(",", 1)
        image_data = base64.b64decode(encoded)
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
    elif isinstance(image_input, bytes):
        # Raw image bytes
        image = Image.open(io.BytesIO(image_input)).convert("RGB")
    elif isinstance(image_input, str):
        # File path
        image = Image.open(image_input).convert("RGB")
    else:
        raise ValueError("image_input must be either a file path, bytes, or base64 data URI")

    # Optionally re-encode as JPEG to simulate a lower-quality source image
    if quality < 100:
        buffer = io.BytesIO()
        image.save(buffer, format="JPEG", quality=quality)
        buffer.seek(0)
        image = Image.open(buffer).convert("RGB")

    # Center-crop to the target aspect ratio, then resize to the target size
    input_width, input_height = image.size
    aspect_ratio_target = target_width / target_height
    aspect_ratio_frame = input_width / input_height
    if aspect_ratio_frame > aspect_ratio_target:
        new_width = int(input_height * aspect_ratio_target)
        new_height = input_height
        x_start = (input_width - new_width) // 2
        y_start = 0
    else:
        new_width = input_width
        new_height = int(input_width / aspect_ratio_target)
        x_start = 0
        y_start = (input_height - new_height) // 2

    image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
    image = image.resize((target_width, target_height))

    # Scale pixel values from [0, 255] to [-1, 1]
    frame_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).float()
    frame_tensor = (frame_tensor / 127.5) - 1.0

    # Create a 5D tensor: (batch_size, channels, num_frames, height, width)
    return frame_tensor.unsqueeze(0).unsqueeze(2)

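# Usage sketch (hypothetical path): returns a (1, 3, 1, 704, 1216) tensor with
# values in [-1, 1], ready to be wrapped in a ConditioningItem:
#   tensor = load_image_to_tensor_with_resize_and_crop("input.jpg", 704, 1216)
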
def calculate_padding(
    source_height: int, source_width: int, target_height: int, target_width: int
) -> tuple[int, int, int, int]:
    """Calculate padding to reach target dimensions"""
    pad_height = target_height - source_height
    pad_width = target_width - source_width

    # Split the padding as evenly as possible between the two sides
    pad_top = pad_height // 2
    pad_bottom = pad_height - pad_top
    pad_left = pad_width // 2
    pad_right = pad_width - pad_left

    # Ordered (left, right, top, bottom), matching torch.nn.functional.pad's
    # convention for the last two dimensions
    padding = (pad_left, pad_right, pad_top, pad_bottom)
    return padding

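# Worked example: padding a 704x1216 frame up to 736x1280 gives
#   calculate_padding(704, 1216, 736, 1280) -> (32, 32, 16, 16)
# i.e. 32 px on each side horizontally and 16 px on top and bottom.
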
def prepare_conditioning(
    conditioning_media_paths: List[str],
    conditioning_strengths: List[float],
    conditioning_start_frames: List[int],
    height: int,
    width: int,
    num_frames: int,
    input_image_quality: int = 100,
    pipeline: Optional[LTXVideoPipeline] = None,
) -> Optional[List[ConditioningItem]]:
    """Prepare conditioning items based on input media paths and their parameters"""
    conditioning_items = []
    for path, strength, start_frame in zip(
        conditioning_media_paths, conditioning_strengths, conditioning_start_frames
    ):
        # Load and preprocess the conditioning image
        frame_tensor = load_image_to_tensor_with_resize_and_crop(
            path, height, width, quality=input_image_quality
        )

        # A single image contributes one frame; let the pipeline trim the
        # conditioning sequence so it fits within the generated clip
        if pipeline:
            frame_count = 1
            frame_count = pipeline.trim_conditioning_sequence(
                start_frame, frame_count, num_frames
            )

        conditioning_items.append(
            ConditioningItem(frame_tensor, start_frame, strength)
        )

    return conditioning_items

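# Usage sketch with hypothetical file names: condition a clip on two images,
# full strength at frame 0 and half strength at frame 64:
#   items = prepare_conditioning(
#       ["start.jpg", "end.jpg"], [1.0, 0.5], [0, 64],
#       height=704, width=1216, num_frames=121,
#   )
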
def create_ltx_video_pipeline(
    config: GenerationConfig,
    device: str = "cuda"
) -> Union[LTXVideoPipeline, LTXMultiScalePipeline]:
    """Create and configure the LTX video pipeline"""

    ckpt_path = "/repository/ltxv-2b-0.9.8-distilled.safetensors"
    spatial_upscaler_path = "/repository/ltxv-spatial-upscaler-0.9.8.safetensors"

    allowed_inference_steps = None

    assert os.path.exists(
        ckpt_path
    ), f"Checkpoint path {ckpt_path} does not exist"

    # Read the list of supported inference steps from the checkpoint metadata
    with safe_open(ckpt_path, framework="pt") as f:
        metadata = f.metadata()
        config_str = metadata.get("config")
        configs = json.loads(config_str)
        allowed_inference_steps = configs.get("allowed_inference_steps", None)

    # Load the core models from the single-file checkpoint
    vae = CausalVideoAutoencoder.from_pretrained(ckpt_path)
    transformer = Transformer3DModel.from_pretrained(ckpt_path)

    # Either build a scheduler for the requested sampler, or reuse the
    # scheduler configuration stored in the checkpoint
    if config.sampler and config.sampler != "from_checkpoint":
        scheduler = RectifiedFlowScheduler(
            sampler=("Uniform" if config.sampler.lower() == "uniform" else "LinearQuadratic")
        )
    else:
        scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)

    text_encoder = T5EncoderModel.from_pretrained("/repository/text_encoder")
    patchifier = SymmetricPatchifier(patch_size=1)
    tokenizer = T5Tokenizer.from_pretrained("/repository/tokenizer")

    # Move everything to the target device
    vae = vae.to(device)
    transformer = transformer.to(device)
    text_encoder = text_encoder.to(device)

    # Cast to bfloat16 for memory savings
    vae = vae.to(torch.bfloat16)
    transformer = transformer.to(torch.bfloat16)
    text_encoder = text_encoder.to(torch.bfloat16)

    # Prompt enhancement models are optional and disabled by default
    prompt_enhancer_components = {
        "prompt_enhancer_image_caption_model": None,
        "prompt_enhancer_image_caption_processor": None,
        "prompt_enhancer_llm_model": None,
        "prompt_enhancer_llm_tokenizer": None
    }

    if config.enhance_prompt:
        try:
            # Florence-2 captions the conditioning image; the Llama model
            # rewrites the user prompt into a more detailed one
            prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained(
                "MiaoshouAI/Florence-2-large-PromptGen-v2.0",
                trust_remote_code=True
            )
            prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained(
                "MiaoshouAI/Florence-2-large-PromptGen-v2.0",
                trust_remote_code=True
            )
            prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained(
                "unsloth/Llama-3.2-3B-Instruct",
                torch_dtype="bfloat16",
            )
            prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained(
                "unsloth/Llama-3.2-3B-Instruct",
            )

            prompt_enhancer_components = {
                "prompt_enhancer_image_caption_model": prompt_enhancer_image_caption_model,
                "prompt_enhancer_image_caption_processor": prompt_enhancer_image_caption_processor,
                "prompt_enhancer_llm_model": prompt_enhancer_llm_model,
                "prompt_enhancer_llm_tokenizer": prompt_enhancer_llm_tokenizer
            }
        except Exception as e:
            logger.warning(f"Failed to load prompt enhancer models: {e}")
            config.enhance_prompt = False

    # Assemble the base pipeline
    pipeline = LTXVideoPipeline(
        transformer=transformer,
        patchifier=patchifier,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        scheduler=scheduler,
        vae=vae,
        allowed_inference_steps=allowed_inference_steps,
        **prompt_enhancer_components
    )

    # Wrap it in the multi-scale pipeline when the spatial upscaler is available
    if config.pipeline_type == "multi-scale":
        if os.path.exists(spatial_upscaler_path):
            latent_upsampler = LatentUpsampler.from_pretrained(spatial_upscaler_path)
            latent_upsampler = latent_upsampler.to(device)
            latent_upsampler = latent_upsampler.to(torch.bfloat16)
            pipeline = LTXMultiScalePipeline(pipeline, latent_upsampler=latent_upsampler)
        else:
            logger.warning(f"Spatial upscaler not found at {spatial_upscaler_path}, falling back to base pipeline")

    return pipeline

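# Minimal usage sketch (assumes the /repository layout referenced above):
# build the pipeline once and reuse it across requests, since model loading
# dominates startup time.
#   pipeline = create_ltx_video_pipeline(GenerationConfig().validate_and_adjust(), device="cuda")
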
class EndpointHandler:
    """Handler for the LTX Video endpoint"""

    def __init__(self, model_path: str = "/repository/"):
        """Initialize the endpoint handler

        Args:
            model_path: Path to model weights (unused; weights are loaded from
                fixed paths under /repository/)
        """
        # Enable TF32 for faster matrix multiplications on Ampere+ GPUs
        torch.backends.cuda.matmul.allow_tf32 = True

        # Varnish handles post-processing (encoding, interpolation, upscaling,
        # film grain, audio)
        self.varnish = Varnish(
            device="cuda",
            model_base_dir="/repository/varnish",
            enable_mmaudio=False,
        )

        # The generation pipeline is created lazily (here, during warm-up)
        self.pipeline = None

        logger.info("Performing warm-up inference...")
        self._warmup()
        logger.info("Warm-up completed!")

    def _warmup(self):
        """Perform a warm-up inference to prepare the model for future requests"""
        try:
            # Small, fast settings keep the warm-up cheap
            test_config = GenerationConfig(
                prompt="an astronaut is riding a cow in the desert, during golden hour",
                negative_prompt="worst quality, lowres",
                width=768,
                height=416,
                num_frames=33,  # (8 * 4) + 1
                guidance_scale=1.0,
                num_inference_steps=4,
                seed=42,
                fps=16,
                enable_audio=False,
                mixed_precision=True,
            ).validate_and_adjust()

            if self.pipeline is None:
                self.pipeline = create_ltx_video_pipeline(test_config)

            with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16), torch.no_grad():
                # Seed every RNG for reproducibility
                random.seed(test_config.seed)
                np.random.seed(test_config.seed)
                torch.manual_seed(test_config.seed)
                generator = torch.Generator(device='cuda').manual_seed(test_config.seed)

                if test_config.pipeline_type == "multi-scale" and isinstance(self.pipeline, LTXMultiScalePipeline):
                    # Truncated two-pass schedules for a fast warm-up
                    first_pass = {
                        "timesteps": DEFAULT_FIRST_PASS_TIMESTEPS[:4],
                        "guidance_scale": 1,
                        "stg_scale": 0,
                        "rescaling_scale": 1,
                        "skip_block_list": [42]
                    }

                    second_pass = {
                        "timesteps": DEFAULT_SECOND_PASS_TIMESTEPS[:2],
                        "guidance_scale": 1,
                        "stg_scale": 0,
                        "rescaling_scale": 1,
                        "skip_block_list": [42]
                    }

                    result = self.pipeline(
                        downscale_factor=test_config.downscale_factor,
                        first_pass=first_pass,
                        second_pass=second_pass,
                        height=test_config.height,
                        width=test_config.width,
                        num_frames=test_config.num_frames,
                        frame_rate=test_config.fps,
                        prompt=test_config.prompt,
                        negative_prompt=test_config.negative_prompt,
                        generator=generator,
                        output_type="pt",
                        mixed_precision=test_config.mixed_precision,
                        is_video=True,
                        vae_per_channel_normalize=True,
                    )
                else:
                    # Base single-pass pipeline
                    result = self.pipeline(
                        height=test_config.height,
                        width=test_config.width,
                        num_frames=test_config.num_frames,
                        frame_rate=test_config.fps,
                        prompt=test_config.prompt,
                        negative_prompt=test_config.negative_prompt,
                        guidance_scale=test_config.guidance_scale,
                        num_inference_steps=test_config.num_inference_steps,
                        generator=generator,
                        output_type="pt",
                        mixed_precision=test_config.mixed_precision,
                        is_video=True,
                        vae_per_channel_normalize=True,
                    )

                frames = result.images

                # Free the warm-up output immediately
                del result
                torch.cuda.empty_cache()
                gc.collect()

                logger.info(f"Warm-up successful! Generated {frames.shape[2]} frames at {frames.shape[3]}x{frames.shape[4]}")

        except Exception as e:
            # A failed warm-up should not prevent the endpoint from starting
            import traceback
            error_message = f"Warm-up failed (but this is non-critical): {str(e)}\n{traceback.format_exc()}"
            logger.warning(error_message)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process inference requests

        Args:
            data: Request data containing inputs and parameters

        Returns:
            Dictionary with generated video and metadata
        """
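        # Example request body (illustrative sketch; field names match the
        # parsing below):
        # {
        #     "inputs": {"prompt": "a cat walking on grass", "image": "data:image/jpeg;base64,..."},
        #     "parameters": {"width": 1216, "height": 704, "num_frames": 113, "seed": 42}
        # }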
						        inputs = data.get("inputs", {}) | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        if isinstance(inputs, str): | 
					
					
						
						| 
							 | 
						            input_prompt = inputs | 
					
					
						
						| 
							 | 
						            input_image = None | 
					
					
						
						| 
							 | 
						        else: | 
					
					
						
						| 
							 | 
						            input_prompt = inputs.get("prompt", "") | 
					
					
						
						| 
							 | 
						            input_image = inputs.get("image") | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        params = data.get("parameters", {}) | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        if not input_prompt and not input_image: | 
					
					
						
						| 
							 | 
						            raise ValueError("Either prompt or image must be provided") | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        config = GenerationConfig( | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            prompt=input_prompt, | 
					
					
						
						| 
							 | 
						            negative_prompt=params.get("negative_prompt", GenerationConfig.negative_prompt), | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            width=params.get("width", GenerationConfig.width), | 
					
					
						
						| 
							 | 
						            height=params.get("height", GenerationConfig.height), | 
					
					
						
						| 
							 | 
						            input_image_quality=params.get("input_image_quality", GenerationConfig.input_image_quality), | 
					
					
						
						| 
							 | 
						            num_frames=params.get("num_frames", GenerationConfig.num_frames), | 
					
					
						
						| 
							 | 
						            guidance_scale=params.get("guidance_scale", GenerationConfig.guidance_scale), | 
					
					
						
						| 
							 | 
						            num_inference_steps=params.get("num_inference_steps", GenerationConfig.num_inference_steps), | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            stg_scale=params.get("stg_scale", GenerationConfig.stg_scale), | 
					
					
						
						| 
							 | 
						            stg_rescale=params.get("stg_rescale", GenerationConfig.stg_rescale), | 
					
					
						
						| 
							 | 
						            stg_mode=params.get("stg_mode", GenerationConfig.stg_mode), | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            decode_timestep=params.get("decode_timestep", GenerationConfig.decode_timestep), | 
					
					
						
						| 
							 | 
						            decode_noise_scale=params.get("decode_noise_scale", GenerationConfig.decode_noise_scale), | 
					
					
						
						| 
							 | 
						            image_cond_noise_scale=params.get("image_cond_noise_scale", GenerationConfig.image_cond_noise_scale), | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            seed=params.get("seed", GenerationConfig.seed), | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            fps=params.get("fps", GenerationConfig.fps), | 
					
					
						
						| 
							 | 
						            double_num_frames=params.get("double_num_frames", GenerationConfig.double_num_frames), | 
					
					
						
						| 
							 | 
						            super_resolution=params.get("super_resolution", GenerationConfig.super_resolution), | 
					
					
						
						| 
							 | 
						            grain_amount=params.get("grain_amount", GenerationConfig.grain_amount), | 
					
					
						
						| 
							 | 
						            enable_audio=params.get("enable_audio", GenerationConfig.enable_audio), | 
					
					
						
						| 
							 | 
						            audio_prompt=params.get("audio_prompt", GenerationConfig.audio_prompt), | 
					
					
						
						| 
							 | 
						            audio_negative_prompt=params.get("audio_negative_prompt", GenerationConfig.audio_negative_prompt), | 
					
					
						
						| 
							 | 
						            quality=params.get("quality", GenerationConfig.quality), | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            mixed_precision=params.get("mixed_precision", GenerationConfig.mixed_precision), | 
					
					
						
						| 
							 | 
						            stochastic_sampling=params.get("stochastic_sampling", GenerationConfig.stochastic_sampling), | 
					
					
						
						| 
							 | 
						            sampler=params.get("sampler", GenerationConfig.sampler), | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            pipeline_type=params.get("pipeline_type", GenerationConfig.pipeline_type), | 
					
					
						
						| 
							 | 
						            downscale_factor=params.get("downscale_factor", GenerationConfig.downscale_factor), | 
					
					
						
						| 
							 | 
						            first_pass_timesteps=params.get("first_pass_timesteps", GenerationConfig.first_pass_timesteps), | 
					
					
						
						| 
							 | 
						            second_pass_timesteps=params.get("second_pass_timesteps", GenerationConfig.second_pass_timesteps), | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						             | 
					
					
						
						| 
							 | 
						            enhance_prompt=params.get("enhance_prompt", GenerationConfig.enhance_prompt), | 
					
					
						
						| 
							 | 
						            prompt_enhancement_words_threshold=params.get( | 
					
					
						
						| 
							 | 
						                "prompt_enhancement_words_threshold",  | 
					
					
						
						| 
							 | 
						                GenerationConfig.prompt_enhancement_words_threshold | 
					
					
						
						| 
							 | 
						            ), | 
					
					
						
						| 
							 | 
						        ).validate_and_adjust() | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        try: | 
					
					
						
						| 
							 | 
						            with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16), torch.no_grad(): | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						                random.seed(config.seed) | 
					
					
						
						| 
							 | 
						                np.random.seed(config.seed) | 
					
					
						
						| 
							 | 
						                torch.manual_seed(config.seed) | 
					
					
						
						| 
							 | 
						                generator = torch.Generator(device='cuda').manual_seed(config.seed) | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						                if self.pipeline is None: | 
					
					
						
						| 
							 | 
						                    self.pipeline = create_ltx_video_pipeline(config) | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						                conditioning_items = None | 
					
					
						
						| 
							 | 
						                if input_image: | 
					
					
						
						| 
							 | 
						                    conditioning_items = [ | 
					
					
						
						| 
							 | 
						                        ConditioningItem( | 
					
					
						
						| 
							 | 
						                            load_image_to_tensor_with_resize_and_crop( | 
					
					
						
						| 
							 | 
						                                input_image,  | 
					
					
						
						| 
							 | 
						                                config.height,  | 
					
					
						
						| 
							 | 
						                                config.width, | 
					
					
						
						| 
							 | 
						                                quality=config.input_image_quality | 
					
					
						
						| 
							 | 
						                            ), | 
					
					
						
						| 
							 | 
						                            0,   | 
					
					
						
						| 
							 | 
						                            1.0   | 
					
					
						
						| 
							 | 
						                        ) | 
					
					
						
						| 
							 | 
						                    ] | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						                if config.stg_mode == "attention_values": | 
					
					
						
						| 
							 | 
						                    skip_layer_strategy = SkipLayerStrategy.AttentionValues | 
					
					
						
						| 
							 | 
						                elif config.stg_mode == "attention_skip": | 
					
					
						
						| 
							 | 
						                    skip_layer_strategy = SkipLayerStrategy.AttentionSkip | 
					
					
						
						| 
							 | 
						                elif config.stg_mode == "residual": | 
					
					
						
						| 
							 | 
						                    skip_layer_strategy = SkipLayerStrategy.Residual | 
					
					
						
						| 
							 | 
						                elif config.stg_mode == "transformer_block": | 
					
					
						
						| 
							 | 
						                    skip_layer_strategy = SkipLayerStrategy.TransformerBlock | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
						                 | 
					
					
						
						| 
							 | 
                # Multi-scale: a downscaled first pass followed by an upscaled refinement pass
                if config.pipeline_type == "multi-scale" and isinstance(self.pipeline, LTXMultiScalePipeline):
                    first_pass = {
                        # a guidance_scale left at the 3.0 default is treated as unset (CFG off)
                        "guidance_scale": config.guidance_scale if config.guidance_scale != 3.0 else 1,
                        "stg_scale": config.stg_scale,
                        "rescaling_scale": config.stg_rescale,
                        "skip_block_list": [42]
                    }
                    second_pass = {
                        "guidance_scale": config.guidance_scale if config.guidance_scale != 3.0 else 1,
                        "stg_scale": config.stg_scale,
                        "rescaling_scale": config.stg_rescale,
                        "skip_block_list": [42]
                    }
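
                    # Informational: STG perturbs the model by skipping the blocks in
                    # skip_block_list and steers generation away from the perturbed output;
                    # see SkipLayerStrategy for each mode's exact semantics.
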
                    # Explicit user-provided timesteps take priority over derived schedules
                    if config.first_pass_timesteps is not None:
                        first_pass["timesteps"] = config.first_pass_timesteps
                        second_pass["timesteps"] = config.second_pass_timesteps or DEFAULT_SECOND_PASS_TIMESTEPS
                    # Otherwise derive a schedule from the requested step count
                    elif config.num_inference_steps != 8:
                        if config.num_inference_steps <= 4:
                            # Hand-picked schedules for very low step counts
                            if config.num_inference_steps == 1:
                                first_pass["timesteps"] = [1.0]
                                second_pass["timesteps"] = [0.4219]
                            elif config.num_inference_steps == 2:
                                first_pass["timesteps"] = [1.0]
                                second_pass["timesteps"] = [0.9094, 0.4219]
                            elif config.num_inference_steps == 3:
                                first_pass["timesteps"] = [1.0, 0.9094]
                                second_pass["timesteps"] = [0.9094, 0.4219]
                            else:  # 4 steps
                                first_pass["timesteps"] = [1.0, 0.975, 0.9094]
                                second_pass["timesteps"] = [0.9094, 0.725, 0.4219]
                        else:
                            # Split the step budget roughly 70/30 across the two passes
                            first_pass_steps = max(1, int(config.num_inference_steps * 0.7))
                            second_pass_steps = max(1, config.num_inference_steps - first_pass_steps)

                            first_pass["timesteps"] = generate_valid_timesteps(first_pass_steps, ALLOWED_TIMESTEPS)

                            # The second pass starts at 0.9094, index 5 of ALLOWED_TIMESTEPS
                            start_idx = 5
                            if second_pass_steps == 1:
                                second_pass["timesteps"] = [ALLOWED_TIMESTEPS[start_idx]]
                            else:
                                end_idx = min(len(ALLOWED_TIMESTEPS), start_idx + second_pass_steps)
                                second_pass["timesteps"] = ALLOWED_TIMESTEPS[start_idx:end_idx]
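
                            # Worked example (illustrative): num_inference_steps=10 gives
                            # first_pass_steps=7 and second_pass_steps=3, so the second pass
                            # uses ALLOWED_TIMESTEPS[5:8] == [0.9094, 0.725, 0.4219]; the
                            # first-pass schedule is whatever generate_valid_timesteps returns.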
                    else:
                        # Default 8-step schedule
                        first_pass["timesteps"] = DEFAULT_FIRST_PASS_TIMESTEPS
                        second_pass["timesteps"] = DEFAULT_SECOND_PASS_TIMESTEPS

                    result = self.pipeline(
                        downscale_factor=config.downscale_factor,
                        first_pass=first_pass,
                        second_pass=second_pass,
                        height=config.height,
                        width=config.width,
                        num_frames=config.num_frames,
                        frame_rate=config.fps,
                        prompt=config.prompt,
                        negative_prompt=config.negative_prompt,
                        generator=generator,
                        output_type="pt",
                        skip_layer_strategy=skip_layer_strategy,
                        conditioning_items=conditioning_items,
                        decode_timestep=config.decode_timestep,
                        decode_noise_scale=config.decode_noise_scale,
                        image_cond_noise_scale=config.image_cond_noise_scale,
                        mixed_precision=config.mixed_precision,
                        is_video=True,
                        vae_per_channel_normalize=True,
                        stochastic_sampling=config.stochastic_sampling,
                        enhance_prompt=config.enhance_prompt,
                    )
                else:
                    # Single pass at the target resolution
                    result = self.pipeline(
                        height=config.height,
                        width=config.width,
                        num_frames=config.num_frames,
                        frame_rate=config.fps,
                        prompt=config.prompt,
                        negative_prompt=config.negative_prompt,
                        guidance_scale=config.guidance_scale,
                        num_inference_steps=config.num_inference_steps,
                        generator=generator,
                        output_type="pt",
                        skip_layer_strategy=skip_layer_strategy,
                        stg_scale=config.stg_scale,
                        do_rescaling=config.stg_rescale != 1.0,  # only rescale for non-neutral values
                        rescaling_scale=config.stg_rescale,
                        conditioning_items=conditioning_items,
                        decode_timestep=config.decode_timestep,
                        decode_noise_scale=config.decode_noise_scale,
                        image_cond_noise_scale=config.image_cond_noise_scale,
                        mixed_precision=config.mixed_precision,
                        is_video=True,
                        vae_per_channel_normalize=True,
                        stochastic_sampling=config.stochastic_sampling,
                        enhance_prompt=config.enhance_prompt,
                    )

                frames = result.images

                # Pipeline output is [batch, channels, frames, height, width];
                # Varnish expects [frames, channels, height, width]
                frames = frames.squeeze(0)           # drop the batch dimension
                frames = frames.permute(1, 0, 2, 3)  # [C, F, H, W] -> [F, C, H, W]

                # Scale [0, 1] floats to [0, 255] and clamp to avoid uint8 wrap-around
                frames = (frames * 255.0).clamp(0, 255)
                frames = frames.to(torch.uint8)
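
                # Shape sketch (illustrative, assuming RGB output): a 121-frame 512x768
                # clip is a [121, 3, 512, 768] uint8 tensor at this point.
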
                # Varnish exposes an async API; reuse the current event loop or create one
                import asyncio
                try:
                    loop = asyncio.get_event_loop()
                except RuntimeError:
                    loop = asyncio.new_event_loop()
                    asyncio.set_event_loop(loop)

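                # Design note: reusing a long-lived loop avoids creating and closing a
                # fresh event loop (as asyncio.run() would) on every request.
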
                # Post-process the raw frames (frame interpolation, upscaling, film grain, audio)
                varnish_result = loop.run_until_complete(
                    self.varnish(
                        frames,
                        fps=config.fps,
                        double_num_frames=config.double_num_frames,
                        super_resolution=config.super_resolution,
                        grain_amount=config.grain_amount,
                        enable_audio=config.enable_audio,
                        audio_prompt=config.audio_prompt or config.prompt,
                        audio_negative_prompt=config.audio_negative_prompt,
                    )
                )

                # Encode the final video as an MP4 data URI
                video_uri = loop.run_until_complete(
                    varnish_result.write(
                        type="data-uri",
                        quality=config.quality
                    )
                )

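                # Consumer sketch (illustrative): the returned URI has the form
                # "data:video/mp4;base64,<payload>", so a client can recover the file with:
                #   payload = video_uri.split(",", 1)[1]
                #   open("video.mp4", "wb").write(base64.b64decode(payload))
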
                metadata = {
                    "width": varnish_result.metadata.width,
                    "height": varnish_result.metadata.height,
                    "num_frames": varnish_result.metadata.frame_count,
                    "fps": varnish_result.metadata.fps,
                    "duration": varnish_result.metadata.duration,
                    "seed": config.seed,
                    "prompt": config.prompt,
                }

                # Free GPU memory held by the denoising result before returning
                del result
                torch.cuda.empty_cache()
                gc.collect()

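                # Note: torch.cuda.empty_cache() returns cached blocks to the driver so
                # the freed VRAM is visible outside this process; PyTorch would otherwise
                # keep and reuse the cache internally.
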
                return {
                    "video": video_uri,
                    "content-type": "video/mp4",
                    "metadata": metadata
                }

        except Exception as e:
            import traceback
            error_message = f"Error generating video: {str(e)}\n{traceback.format_exc()}"
            logger.error(error_message)
            raise RuntimeError(error_message) from e