Fine-Tuning?
#2 opened by hanshupe
Can it be fine-tuned the same way as Qwen 2.5 VL, with SFTTrainer and PEFT? And how is it actually related to Qwen 2.5 VL?
Trying to fine-tune with PEFT on model.llm gives me this error:
File "/home/.../huggingface/modules/transformers_modules/AIDC-AI/Ovis2.5-9B/.../modeling_ovis2_5.py", line 778, in forward
inputs_embeds = self.merge_multimodal(
^^^^^^^^^^^^^^^^^^^^^^
File "/home/.../huggingface/modules/transformers_modules/AIDC-AI/Ovis2.5-9B/.../modeling_ovis2_5.py", line 802, in merge_multimodal
multimodal_embeds[input_ids == indicator_id] = visual_indicator_embeds[i]
~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
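For context, this is a general PyTorch autograd restriction rather than a PEFT bug: merge_multimodal writes the visual embeddings into the token-embedding tensor in place, and once that tensor is a leaf that requires grad (which is what prepare_model_for_kbit_training / enable_input_require_grads sets up for gradient checkpointing), PyTorch rejects the in-place write. A minimal, model-agnostic sketch that reproduces the same message:

# Minimal sketch of the autograd restriction behind the error (not Ovis code)
import torch

embeds = torch.zeros(4, 8, requires_grad=True)  # a leaf tensor that requires grad
try:
    embeds[0] = torch.randn(8)  # in-place write into the leaf
except RuntimeError as e:
    print(e)  # "a leaf Variable that requires grad is being used in an in-place operation"

# Writing into a clone first avoids it (this is what the patch posted later in this thread does):
patched = embeds.clone()  # clone output is not a leaf
patched[0] = torch.randn(8)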
I have some issues too. If someone got it working, please let us know.
@hanshupe
@Vil
Hi~ The official code already supports fine-tuning, see here:
https://github.com/AIDC-AI/Ovis?tab=readme-ov-file#model-fine-tuning
Fine-tuning Ovis2.5-9B with LoRA + Gradient Checkpointing Fix
The following code worked for me for fine-tuning AIDC-AI/Ovis2.5-9B with 4-bit quantization and LoRA; it includes a critical patch that makes merge_multimodal compatible with gradient checkpointing and DeepSpeed ZeRO-3.
from transformers import AutoModelForCausalLM, AutoProcessor, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import torch
from PIL import Image
import logging
from typing import Dict, Sequence
import torch.nn.functional as F
import os
import types
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Define constants (from ovis/util/constants.py and model config)
IGNORE_ID = -100
IMAGE_TOKEN = "<image>"
# Special token IDs (from Ovis2.5 configuration)
IMAGE_TOKEN_ID = 151665
VISUAL_INDICATOR_IDS = [151666, 151667, 151668, 151669, 151670]
# ============================================================================
# 1. MODEL SETUP
# ============================================================================
logger.info("Loading model...")
# Load model with 4-bit quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
    "AIDC-AI/Ovis2.5-9B",
    trust_remote_code=True,
    quantization_config=bnb_config,
    device_map=torch.cuda.current_device()
)
# Use AutoProcessor
processor = AutoProcessor.from_pretrained(
    "AIDC-AI/Ovis2.5-9B",
    trust_remote_code=True
)
# Access text tokenizer
text_tokenizer = processor.tokenizer if hasattr(processor, 'tokenizer') else processor
# Add missing embedding methods required by PEFT
model.get_input_embeddings = lambda: model.llm.get_input_embeddings()
model.get_output_embeddings = lambda: model.llm.get_output_embeddings()
# Freeze vision encoder and visual tokenizer
for param in model.visual_tokenizer.parameters():
    param.requires_grad = False
for param in model.vte.parameters():
    param.requires_grad = False
logger.info("Vision encoder frozen ✓")
# Prepare model for k-bit training
model = prepare_model_for_kbit_training(model)
# Configure LoRA (with RS-LoRA)
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=r".*llm\.model\.layers\.\d+\.(self_attn\.(q_proj|k_proj|v_proj|o_proj)|mlp\.(gate_proj|up_proj|down_proj))",
    lora_dropout=0.0,
    bias="none",
    task_type="CAUSAL_LM",
    use_rslora=True,
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# Update references after LoRA wrapping
base_model = model.base_model.model
visual_tokenizer = base_model.visual_tokenizer
# Verification of trainable parameters
logger.info("\n=== Detailed Verification ===")
vision_trainable = sum(p.numel() for p in base_model.visual_tokenizer.parameters() if p.requires_grad)
vte_trainable = sum(p.numel() for p in base_model.vte.parameters() if p.requires_grad)
llm_base_trainable = sum(p.numel() for n, p in model.llm.named_parameters() if p.requires_grad and 'lora' not in n)
llm_lora_trainable = sum(p.numel() for n, p in model.llm.named_parameters() if p.requires_grad and 'lora' in n)
logger.info(f"Visual tokenizer trainable params: {vision_trainable:,} (should be 0)")
logger.info(f"Visual embedding trainable params: {vte_trainable:,} (should be 0)")
logger.info(f"LLM base trainable params: {llm_base_trainable:,}")
logger.info(f"LLM LoRA trainable params: {llm_lora_trainable:,}")
def patched_merge_multimodal(self, input_ids, pixel_values, grid_thws):
    """
    Patched version avoiding in-place operations for gradient checkpointing compatibility.
    Original code at lines 324-343 of modeling_ovis.py
    """
    VISUAL_ATOM_ID = -300
    INDICATOR_IDS = [-301, -302, -303, -304]
    from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

    placeholder_token_mask = torch.lt(input_ids, 0)
    multimodal_embeds = self.get_wte()(torch.masked_fill(input_ids, placeholder_token_mask, 0))

    need_dummy_visual_input = pixel_values is None and (self.training or is_deepspeed_zero3_enabled())
    if need_dummy_visual_input:
        pixel_values, grid_thws = self.visual_tokenizer.get_dummy_visual_inputs()

    if pixel_values is not None:
        target_device = multimodal_embeds.device
        target_dtype = multimodal_embeds.dtype
        visual_indicator_embeds = self.vte(torch.tensor(
            list(range(self.config.visual_vocab_size - len(INDICATOR_IDS), self.config.visual_vocab_size)),
            dtype=torch.long,
            device=self.vte.weight.device
        )).to(dtype=target_dtype, device=target_device)
        visual_tokens = self.visual_tokenizer(pixel_values, grid_thws)
        visual_embeds = self.vte(visual_tokens).to(dtype=target_dtype, device=target_device)

        new_embeds = multimodal_embeds.clone()
        # Replace indicator embeddings
        for i, indicator_id in enumerate(INDICATOR_IDS):
            mask = (input_ids == indicator_id)
            if mask.any():
                positions = mask.nonzero(as_tuple=False)
                for pos in positions:
                    batch_idx, seq_idx = pos[0].item(), pos[1].item()
                    new_embeds[batch_idx, seq_idx] = visual_indicator_embeds[i]
        # Replace visual atom embeddings
        visual_atom_mask = (input_ids == VISUAL_ATOM_ID)
        if visual_atom_mask.any():
            batch_size, seq_len = input_ids.shape
            visual_idx = 0
            for b in range(batch_size):
                positions = torch.where(visual_atom_mask[b])[0]
                for pos in positions:
                    if visual_idx < visual_embeds.size(0):
                        new_embeds[b, pos] = visual_embeds[visual_idx]
                        visual_idx += 1
        multimodal_embeds = new_embeds

        if need_dummy_visual_input:
            multimodal_embeds = multimodal_embeds + visual_embeds.sum() * 0.0 + visual_indicator_embeds.sum() * 0.0

    return multimodal_embeds
def patched_forward(
    self,
    input_ids,
    attention_mask,
    pixel_values,
    grid_thws,
    labels,
    **kwargs
):
    """Patched forward to avoid a duplicate inputs_embeds argument"""
    kwargs.pop('inputs_embeds', None)
    inputs_embeds = self.merge_multimodal(
        input_ids=input_ids,
        pixel_values=pixel_values,
        grid_thws=grid_thws,
    )
    return self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, **kwargs)
# Apply patches
base_model.forward = types.MethodType(patched_forward, base_model)
base_model.merge_multimodal = types.MethodType(patched_merge_multimodal, base_model)
logger.info("Applied merge_multimodal patch for gradient checkpointing β")
class DataCollatorForMultimodalDataset:
    def __init__(self, text_tokenizer):
        self.text_tokenizer = text_tokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, pixel_values, grid_thws, attention_mask, labels = [], [], [], [], []
        for instance in instances:
            input_ids.append(torch.tensor(instance["input_ids"], dtype=torch.long)
                             if isinstance(instance["input_ids"], list) else instance["input_ids"])
            pv = instance["pixel_values"]
            if pv is not None:
                pixel_values.append(torch.tensor(pv) if isinstance(pv, list) else pv)
            gt = instance["grid_thws"]
            if gt is not None:
                grid_thws.append(torch.tensor(gt) if isinstance(gt, list) else gt)
            attention_mask.append(torch.tensor(instance["attention_mask"], dtype=torch.bool)
                                  if isinstance(instance["attention_mask"], list) else instance["attention_mask"])
            labels.append(torch.tensor(instance["labels"], dtype=torch.long)
                          if isinstance(instance["labels"], list) else instance["labels"])
        # Pad sequences
        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.text_tokenizer.pad_token_id)
        attention_mask = torch.nn.utils.rnn.pad_sequence(attention_mask, batch_first=True, padding_value=False)
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_ID)
        pixel_values = torch.cat(pixel_values, dim=0) if pixel_values else None
        grid_thws = torch.cat(grid_thws, dim=0) if grid_thws else None
        # Ensure the batch contains at least one padded position
        if 0 not in attention_mask:
            input_ids = F.pad(input_ids, (0, 1), value=self.text_tokenizer.pad_token_id)
            attention_mask = F.pad(attention_mask, (0, 1), value=False)
            labels = F.pad(labels, (0, 1), value=IGNORE_ID)
        if torch.all(labels == IGNORE_ID):
            logging.warning('[DataCollatorForMultimodalDataset] All samples in the current batch are ignored.')
        return dict(
            input_ids=input_ids,
            pixel_values=pixel_values,
            grid_thws=grid_thws,
            attention_mask=attention_mask,
            labels=labels
        )
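For completeness, here is a minimal sketch of how the collator could be plugged into a standard Hugging Face Trainer. Everything below is illustrative rather than part of the original post: my_train_dataset is a placeholder for whatever dataset you build with the Ovis processor (each item must already contain input_ids, pixel_values, grid_thws, attention_mask and labels), and the hyperparameters are just examples.

# Hypothetical wiring sketch; adjust to your own dataset and hardware.
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./ovis25-9b-lora",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=1e-4,
    num_train_epochs=1,
    bf16=True,
    logging_steps=10,
    save_strategy="epoch",
    remove_unused_columns=False,  # keep pixel_values / grid_thws in the batch
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=my_train_dataset,  # placeholder: yields dicts with the fields listed above
    data_collator=DataCollatorForMultimodalDataset(text_tokenizer),
)
trainer.train()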