Fine-Tuning?

#2
by hanshupe - opened

Can it be fine-tuned the same way as Qwen2.5-VL with SFTTrainer and PEFT, and how is it actually related to Qwen2.5-VL?

Trying to fine-tune with PEFT on model.llm gives me an error:

  File "/home/.../huggingface/modules/transformers_modules/AIDC-AI/Ovis2.5-9B/.../modeling_ovis2_5.py", line 778, in forward
    inputs_embeds = self.merge_multimodal(
                    ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/.../huggingface/modules/transformers_modules/AIDC-AI/Ovis2.5-9B/.../modeling_ovis2_5.py", line 802, in merge_multimodal
    multimodal_embeds[input_ids == indicator_id] = visual_indicator_embeds[i]
    ~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
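
For context, this RuntimeError is generic PyTorch behaviour rather than anything Ovis-specific: autograd refuses in-place writes into a leaf tensor that requires grad (which is what gradient checkpointing tends to produce for recomputed inputs). A minimal standalone reproduction of the same failure mode, purely illustrative:

import torch

x = torch.zeros(4, 8, requires_grad=True)  # leaf tensor that requires grad
x[0] = torch.ones(8)                       # raises the same RuntimeError

# Writing into a clone is allowed, which is essentially what the patch
# further down this thread does for multimodal_embeds.
y = x.clone()
y[0] = torch.ones(8)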

I have some issues too. If someone got it working, please let us know.

AIDC-AI org

@hanshupe @Vil Hi~ The official code already supports fine-tuning; see here:
https://github.com/AIDC-AI/Ovis?tab=readme-ov-file#model-fine-tuning

Fine-tuning Ovis2.5-9B with LoRA + Gradient Checkpointing Fix

The following code worked for me when fine-tuning AIDC-AI/Ovis2.5-9B using 4-bit quantization and LoRA, including a critical patch to make merge_multimodal compatible with gradient checkpointing and DeepSpeed ZeRO-3.

from transformers import AutoModelForCausalLM, AutoProcessor, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import torch
from PIL import Image
import logging
from typing import Dict, Sequence
import torch.nn.functional as F
import os
import types

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Define constants (from ovis/util/constants.py and model config)
IGNORE_ID = -100
IMAGE_TOKEN = "<image>"

# Special token IDs (from Ovis2.5 configuration)
IMAGE_TOKEN_ID = 151665
VISUAL_INDICATOR_IDS = [151666, 151667, 151668, 151669, 151670]

# ============================================================================
# 1. MODEL SETUP
# ============================================================================

logger.info("Loading model...")

# Load model with 4-bit quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    "AIDC-AI/Ovis2.5-9B",
    trust_remote_code=True,
    quantization_config=bnb_config,
    device_map=torch.cuda.current_device()
)

# Use AutoProcessor
processor = AutoProcessor.from_pretrained(
    "AIDC-AI/Ovis2.5-9B",
    trust_remote_code=True
)

# Access text tokenizer
text_tokenizer = processor.tokenizer if hasattr(processor, 'tokenizer') else processor

# Add missing embedding methods required by PEFT
model.get_input_embeddings = lambda: model.llm.get_input_embeddings()
model.get_output_embeddings = lambda: model.llm.get_output_embeddings()

# Freeze vision encoder and visual tokenizer
for param in model.visual_tokenizer.parameters():
    param.requires_grad = False
for param in model.vte.parameters():
    param.requires_grad = False

logger.info("Vision encoder frozen βœ“")

# Prepare model for k-bit training
model = prepare_model_for_kbit_training(model)

# Configure LoRA (with RS-LoRA)
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=r".*llm\.model\.layers\.\d+\.(self_attn\.(q_proj|k_proj|v_proj|o_proj)|mlp\.(gate_proj|up_proj|down_proj))",
    lora_dropout=0.0,
    bias="none",
    task_type="CAUSAL_LM",
    use_rslora=True,
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# Update references after LoRA wrapping
base_model = model.base_model.model
visual_tokenizer = base_model.visual_tokenizer

# Verification of trainable parameters
logger.info("\n=== Detailed Verification ===")
vision_trainable = sum(p.numel() for p in base_model.visual_tokenizer.parameters() if p.requires_grad)
vte_trainable = sum(p.numel() for p in base_model.vte.parameters() if p.requires_grad)
llm_base_trainable = sum(p.numel() for n, p in model.llm.named_parameters() if p.requires_grad and 'lora' not in n)
llm_lora_trainable = sum(p.numel() for n, p in model.llm.named_parameters() if p.requires_grad and 'lora' in n)

logger.info(f"Visual tokenizer trainable params: {vision_trainable:,} (should be 0)")
logger.info(f"Visual embedding trainable params: {vte_trainable:,} (should be 0)")
logger.info(f"LLM base trainable params: {llm_base_trainable:,}")
logger.info(f"LLM LoRA trainable params: {llm_lora_trainable:,}")

# ============================================================================
# 2. GRADIENT CHECKPOINTING PATCH
# ============================================================================

def patched_merge_multimodal(self, input_ids, pixel_values, grid_thws):
    """
    Patched merge_multimodal (from modeling_ovis2_5.py) that avoids in-place
    writes, for compatibility with gradient checkpointing and DeepSpeed ZeRO-3.
    """
    VISUAL_ATOM_ID = -300
    INDICATOR_IDS = [-301, -302, -303, -304]
    from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
    
    placeholder_token_mask = torch.lt(input_ids, 0)
    multimodal_embeds = self.get_wte()(torch.masked_fill(input_ids, placeholder_token_mask, 0))
    
    need_dummy_visual_input = pixel_values is None and (self.training or is_deepspeed_zero3_enabled())
    if need_dummy_visual_input:
        pixel_values, grid_thws = self.visual_tokenizer.get_dummy_visual_inputs()
    
    if pixel_values is not None:
        target_device = multimodal_embeds.device
        target_dtype = multimodal_embeds.dtype
        
        visual_indicator_embeds = self.vte(torch.tensor(
            list(range(self.config.visual_vocab_size - len(INDICATOR_IDS), self.config.visual_vocab_size)),
            dtype=torch.long,
            device=self.vte.weight.device
        )).to(dtype=target_dtype, device=target_device)
        
        visual_tokens = self.visual_tokenizer(pixel_values, grid_thws)
        visual_embeds = self.vte(visual_tokens).to(dtype=target_dtype, device=target_device)
        
        new_embeds = multimodal_embeds.clone()
        
        # Replace indicator embeddings
        for i, indicator_id in enumerate(INDICATOR_IDS):
            mask = (input_ids == indicator_id)
            if mask.any():
                positions = mask.nonzero(as_tuple=False)
                for pos in positions:
                    batch_idx, seq_idx = pos[0].item(), pos[1].item()
                    new_embeds[batch_idx, seq_idx] = visual_indicator_embeds[i]
        
        # Replace visual atom embeddings
        visual_atom_mask = (input_ids == VISUAL_ATOM_ID)
        if visual_atom_mask.any():
            batch_size, seq_len = input_ids.shape
            visual_idx = 0
            for b in range(batch_size):
                positions = torch.where(visual_atom_mask[b])[0]
                for pos in positions:
                    if visual_idx < visual_embeds.size(0):
                        new_embeds[b, pos] = visual_embeds[visual_idx]
                        visual_idx += 1
        
        multimodal_embeds = new_embeds
    
    if need_dummy_visual_input:
        multimodal_embeds = multimodal_embeds + visual_embeds.sum() * 0.0 + visual_indicator_embeds.sum() * 0.0
    
    return multimodal_embeds


def patched_forward(
    self,
    input_ids,
    attention_mask,
    pixel_values,
    grid_thws,
    labels,
    **kwargs
):
    """Patched forward to avoid duplicate inputs_embeds argument"""
    kwargs.pop('inputs_embeds', None)
    
    inputs_embeds = self.merge_multimodal(
        input_ids=input_ids,
        pixel_values=pixel_values,
        grid_thws=grid_thws,
    )
    return self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, **kwargs)


# Apply patches
base_model.forward = types.MethodType(patched_forward, base_model)
base_model.merge_multimodal = types.MethodType(patched_merge_multimodal, base_model)

logger.info("Applied merge_multimodal patch for gradient checkpointing βœ“")

# ============================================================================
# 3. DATA COLLATOR
# ============================================================================

class DataCollatorForMultimodalDataset:
    def __init__(self, text_tokenizer):
        self.text_tokenizer = text_tokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, pixel_values, grid_thws, attention_mask, labels = [], [], [], [], []
        
        for instance in instances:
            input_ids.append(torch.tensor(instance["input_ids"], dtype=torch.long) 
                            if isinstance(instance["input_ids"], list) else instance["input_ids"])
            
            pv = instance["pixel_values"]
            if pv is not None:
                pixel_values.append(torch.tensor(pv) if isinstance(pv, list) else pv)
            
            gt = instance["grid_thws"]
            if gt is not None:
                grid_thws.append(torch.tensor(gt) if isinstance(gt, list) else gt)
            
            attention_mask.append(torch.tensor(instance["attention_mask"], dtype=torch.bool)
                                  if isinstance(instance["attention_mask"], list) else instance["attention_mask"])
            
            labels.append(torch.tensor(instance["labels"], dtype=torch.long)
                          if isinstance(instance["labels"], list) else instance["labels"])
        
        # Pad sequences
        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.text_tokenizer.pad_token_id)
        attention_mask = torch.nn.utils.rnn.pad_sequence(attention_mask, batch_first=True, padding_value=False)
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_ID)
        
        pixel_values = torch.cat(pixel_values, dim=0) if pixel_values else None
        grid_thws = torch.cat(grid_thws, dim=0) if grid_thws else None
        
        # Ensure the batch contains at least one padded position
        if 0 not in attention_mask:
            input_ids = F.pad(input_ids, (0, 1), value=self.text_tokenizer.pad_token_id)
            attention_mask = F.pad(attention_mask, (0, 1), value=False)
            labels = F.pad(labels, (0, 1), value=IGNORE_ID)
        
        if torch.all(labels == IGNORE_ID):
            logging.warning('[DataCollatorForMultimodalDataset] All samples in the current batch are ignored.')
            
        return dict(
            input_ids=input_ids,
            pixel_values=pixel_values,
            grid_thws=grid_thws,
            attention_mask=attention_mask,
            labels=labels
        )
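
For completeness, here is a rough sketch (not part of the original post) of how the patched model and collator could be wired into a standard Hugging Face Trainer. train_dataset is a placeholder for your own torch Dataset whose items already contain input_ids, pixel_values, grid_thws, attention_mask and labels in the format the collator expects (e.g. built with the preprocessing from the official fine-tuning code linked above); treat the hyperparameters as a starting point only.

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./ovis2_5-9b-lora",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=1e-4,
    num_train_epochs=1,
    bf16=True,
    gradient_checkpointing=True,
    logging_steps=10,
    save_strategy="epoch",
    remove_unused_columns=False,  # keep pixel_values / grid_thws in the batch
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,  # placeholder: your preprocessed dataset
    data_collator=DataCollatorForMultimodalDataset(text_tokenizer),
)
trainer.train()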
