from diffusers.models.attention_processor import FluxAttnProcessor2_0
from safetensors.torch import load_file
import re
import torch

from .layers_cache import MultiDoubleStreamBlockLoraProcessor, MultiSingleStreamBlockLoraProcessor

device = "cuda"


def load_safetensors(path):
    """Safely loads tensors from a file and maps them to the CPU."""
    return load_file(path, device="cpu")


def get_lora_count_from_checkpoint(checkpoint):
    """
    Infers the number of LoRA modules stored in a checkpoint by inspecting its keys.
    Also prints a sample of keys for debugging.
    """
    lora_indices = set()
    # Regex to find '..._loras.X.' where X is a number.
    indexed_pattern = re.compile(r'._loras\.(\d+)\.')
    found_keys = []
    for key in checkpoint.keys():
        match = indexed_pattern.search(key)
        if match:
            lora_indices.add(int(match.group(1)))
            if len(found_keys) < 5 and key not in found_keys:
                found_keys.append(key)

    if lora_indices:
        lora_count = max(lora_indices) + 1
        print("INFO: Auto-detected indexed LoRA keys in checkpoint.")
        print(f" Found {lora_count} LoRA module(s).")
        print(" Sample keys:", found_keys)
        return lora_count

    # Fallback for legacy, non-indexed checkpoints.
    legacy_found = False
    legacy_key_sample = ""
    for key in checkpoint.keys():
        if '.q_lora.' in key:
            legacy_found = True
            legacy_key_sample = key
            break
    if legacy_found:
        print("INFO: Auto-detected legacy (non-indexed) LoRA keys in checkpoint.")
        print(" Assuming 1 LoRA module.")
        print(" Sample key:", legacy_key_sample)
        return 1

    print("WARNING: No LoRA keys found in the checkpoint.")
    return 0


def get_lora_ranks(checkpoint, num_loras):
    """
    Determines the rank for each LoRA module from the checkpoint.
    It supports both indexed (e.g., 'loras.0') and legacy non-indexed formats.
    """
    ranks = {}
    # First, try to find ranks for all indexed LoRA modules.
    for i in range(num_loras):
        # Find a key that uniquely identifies the i-th LoRA's down projection.
        rank_pattern = re.compile(rf'._loras\.({i})\.down\.weight')
        for k, v in checkpoint.items():
            if rank_pattern.search(k):
                ranks[i] = v.shape[0]
                break

    # If not all ranks were found, there might be legacy keys or a mismatch.
    if len(ranks) != num_loras:
        # Fallback for single, non-indexed LoRA checkpoints.
        if num_loras == 1:
            for k, v in checkpoint.items():
                if ".q_lora.down.weight" in k:
                    return [v.shape[0]]
        # If still unresolved, use the rank of the very first LoRA found as a default for all.
        first_found_rank = next((v.shape[0] for k, v in checkpoint.items() if k.endswith(".down.weight")), None)
        if first_found_rank is None:
            raise ValueError("Could not determine any LoRA rank from the provided checkpoint.")
        # Return a list where missing ranks are filled with the first one found.
        return [ranks.get(i, first_found_rank) for i in range(num_loras)]

    # Return the list of ranks sorted by LoRA index.
    return [ranks[i] for i in range(num_loras)]
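
# Illustrative example (kept as a comment so nothing runs on import): the two
# helpers above only require that LoRA keys contain "..._loras.<i>." and end in
# ".down.weight" / ".up.weight"; the block prefixes below are just one plausible
# layout, not a format this module enforces. For a synthetic checkpoint such as
#
#     demo = {
#         "transformer_blocks.0.attn.processor.q_loras.0.down.weight": torch.zeros(4, 3072),
#         "transformer_blocks.0.attn.processor.q_loras.0.up.weight": torch.zeros(3072, 4),
#         "transformer_blocks.0.attn.processor.q_loras.1.down.weight": torch.zeros(8, 3072),
#         "transformer_blocks.0.attn.processor.q_loras.1.up.weight": torch.zeros(3072, 8),
#     }
#
# get_lora_count_from_checkpoint(demo) returns 2 (indices 0 and 1), and
# get_lora_ranks(demo, 2) returns [4, 8], since a `down` weight has shape
# (rank, dim). Legacy checkpoints that use ".q_lora." without an index are
# treated as a single LoRA module.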

def load_checkpoint(local_path):
    checkpoint = None  # Returned unchanged if no path is provided.
    if local_path is not None:
        if '.safetensors' in local_path:
            print(f"Loading .safetensors checkpoint from {local_path}")
            checkpoint = load_safetensors(local_path)
        else:
            print(f"Loading checkpoint from {local_path}")
            checkpoint = torch.load(local_path, map_location='cpu')
    return checkpoint


def prepare_lora_processors(checkpoint, lora_weights, transformer, cond_size, number=None):
    # Ensure processors match the transformer's device and dtype.
    try:
        first_param = next(transformer.parameters())
        target_device = first_param.device
        target_dtype = first_param.dtype
    except StopIteration:
        target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        target_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    if number is None:
        number = get_lora_count_from_checkpoint(checkpoint)
    if number == 0:
        return {}

    if lora_weights and len(lora_weights) != number:
        print(f"WARNING: Provided `lora_weights` length ({len(lora_weights)}) differs from detected LoRA count ({number}).")
        final_weights = (lora_weights + [1.0] * number)[:number]
        print(f" Adjusting weights to: {final_weights}")
        lora_weights = final_weights
    elif not lora_weights:
        print(f"INFO: No `lora_weights` provided. Defaulting to weights of 1.0 for all {number} LoRAs.")
        lora_weights = [1.0] * number

    ranks = get_lora_ranks(checkpoint, number)
    print("INFO: Determined ranks for LoRA modules:", ranks)

    cond_widths = cond_size if isinstance(cond_size, list) else [cond_size] * number
    cond_heights = cond_size if isinstance(cond_size, list) else [cond_size] * number

    lora_attn_procs = {}
    double_blocks_idx = list(range(19))
    single_blocks_idx = list(range(38))

    # Get all attention processor names from the transformer to iterate over.
    for name in transformer.attn_processors.keys():
        match = re.search(r'\.(\d+)\.', name)
        if not match:
            continue
        layer_index = int(match.group(1))

        if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
            lora_state_dicts = {
                key: value
                for key, value in checkpoint.items()
                if f"transformer_blocks.{layer_index}." in key
            }
            lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
                dim=3072,
                ranks=ranks,
                network_alphas=ranks,
                lora_weights=lora_weights,
                device=target_device,
                dtype=target_dtype,
                cond_widths=cond_widths,
                cond_heights=cond_heights,
                n_loras=number,
            )
            # Copy the per-LoRA q/k/v/proj projection weights for this double-stream block.
            for n in range(number):
                lora_prefix_q = f"{name}.q_loras.{n}"
                lora_prefix_k = f"{name}.k_loras.{n}"
                lora_prefix_v = f"{name}.v_loras.{n}"
                lora_prefix_proj = f"{name}.proj_loras.{n}"
                lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{lora_prefix_q}.down.weight')
                lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{lora_prefix_q}.up.weight')
                lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{lora_prefix_k}.down.weight')
                lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{lora_prefix_k}.up.weight')
                lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{lora_prefix_v}.down.weight')
                lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{lora_prefix_v}.up.weight')
                lora_attn_procs[name].proj_loras[n].down.weight.data = lora_state_dicts.get(f'{lora_prefix_proj}.down.weight')
                lora_attn_procs[name].proj_loras[n].up.weight.data = lora_state_dicts.get(f'{lora_prefix_proj}.up.weight')
            lora_attn_procs[name].to(device=target_device, dtype=target_dtype)

        elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
            lora_state_dicts = {
                key: value
                for key, value in checkpoint.items()
                if f"single_transformer_blocks.{layer_index}." in key
            }
            lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
                dim=3072,
                ranks=ranks,
                network_alphas=ranks,
                lora_weights=lora_weights,
                device=target_device,
                dtype=target_dtype,
                cond_widths=cond_widths,
                cond_heights=cond_heights,
                n_loras=number,
            )
            # Single-stream blocks only carry q/k/v LoRAs in this format (no proj_loras).
            for n in range(number):
                lora_prefix_q = f"{name}.q_loras.{n}"
                lora_prefix_k = f"{name}.k_loras.{n}"
                lora_prefix_v = f"{name}.v_loras.{n}"
                lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{lora_prefix_q}.down.weight')
                lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{lora_prefix_q}.up.weight')
                lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{lora_prefix_k}.down.weight')
                lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{lora_prefix_k}.up.weight')
                lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{lora_prefix_v}.down.weight')
                lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{lora_prefix_v}.up.weight')
            lora_attn_procs[name].to(device=target_device, dtype=target_dtype)

    return lora_attn_procs
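

# --- Usage sketch (illustrative) ---------------------------------------------
# A minimal, hedged example of how these helpers are typically wired together.
# The pipeline class usage, model id, checkpoint path, and `cond_size` value
# below are placeholders/assumptions, not something this module defines.
if __name__ == "__main__":
    from diffusers import FluxPipeline  # assumes a diffusers version that ships Flux

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16  # placeholder model id
    ).to(device)

    # Load a (possibly multi-)LoRA checkpoint and build one processor per attention layer.
    ckpt = load_checkpoint("path/to/multi_lora.safetensors")  # placeholder path
    procs = prepare_lora_processors(
        ckpt,
        lora_weights=[1.0],          # one weight per LoRA; padded/truncated to the detected count
        transformer=pipe.transformer,
        cond_size=512,               # single int or a per-LoRA list of condition sizes
    )

    # Layers without a LoRA processor keep the stock Flux attention processor.
    all_procs = {
        name: procs.get(name, FluxAttnProcessor2_0())
        for name in pipe.transformer.attn_processors
    }
    pipe.transformer.set_attn_processor(all_procs)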