"""Qwen2.5-VL subclass that conditions the language model on an extra feature vector.

The extra features are projected to the hidden size by a small trainable head
(``prefix_proj``) and added to the embedding of the first input token. All
other parameters are frozen in ``__init__``.
"""
import math
from contextlib import contextmanager

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import kaiming_uniform_, zeros_
from transformers import (
    AutoModelForCausalLM,
    AutoModelForVision2Seq,
    Qwen2_5_VLForConditionalGeneration,
)

from .config import Qwen2_5_VLLinearConfig


class Qwen2_5_VLLinearForCausalLM(Qwen2_5_VLForConditionalGeneration):
    """Qwen2.5-VL with a trainable linear "prefix" added to the first token embedding."""

    config_class = Qwen2_5_VLLinearConfig

    def __init__(self, config):
        """Build the base model, add the prefix projection and freeze the rest.

        Args:
            config: model configuration; ``config.extra_feat_dim`` and
                ``config.hidden_size`` size the projection layer.
        """
        super().__init__(config)

        self.prefix_proj = nn.Sequential(
            nn.Linear(config.extra_feat_dim, config.hidden_size, bias=True),
            nn.SiLU(),  # other activations (GELU / Tanh / ReLU) or dropout could go here
        )
        kaiming_uniform_(self.prefix_proj[0].weight, a=math.sqrt(5))
        zeros_(self.prefix_proj[0].bias)

        # Only the prefix projection is trainable; every other parameter is frozen.
        for name, param in self.named_parameters():
            if not name.startswith("prefix_proj"):
                param.requires_grad_(False)

    @staticmethod
    def _cache_is_empty(past_key_values):
        """Return True when there is no cached context yet (prefill step).

        ``generate`` may hand over an *empty* cache object rather than ``None``
        on the first step, so ``is None`` alone cannot identify the prefill.
        """
        if past_key_values is None:
            return True
        get_seq_length = getattr(past_key_values, "get_seq_length", None)
        if callable(get_seq_length):
            return get_seq_length() == 0
        # Legacy tuple-style cache: empty means nothing has been cached.
        return len(past_key_values) == 0

    @contextmanager
    def _prefix_injection(self, delta):
        """Temporarily add ``delta`` to the first token's embedding.

        A forward hook is registered on the input embedding layer for the
        duration of the ``with`` block and always removed afterwards.

        Args:
            delta: tensor added to ``output[:, 0]`` of the embedding layer;
                expected to be (bs, hidden).
        """
        emb_layer = self.get_input_embeddings()

        def hook(module, inputs, output):
            # output is expected to be (bs, seq, hidden). Clone so the
            # embedding layer's own output tensor is never modified in place.
            out = output.clone()
            # .to() is a no-op when delta already matches; it only matters when
            # the projection lives on another device or in another precision.
            out[:, 0] = out[:, 0] + delta.to(device=out.device, dtype=out.dtype)
            return out

        handle = emb_layer.register_forward_hook(hook)
        try:
            yield
        finally:
            handle.remove()

    # ---------------- main forward ----------------
    def forward(
        self,
        input_ids=None,
        prefix_feats=None,
        past_key_values=None,
        attention_mask=None,
        labels=None,
        pixel_values=None,
        image_grid_thw=None,
        rope_deltas=None,
        **kwargs,
    ):
        """Run the model, injecting the projected prefix on the prefill pass.

        Args:
            input_ids: token ids for this step.
            prefix_feats: extra features fed to ``prefix_proj``; required
                whenever there is no cached context.
            past_key_values: cache from earlier steps, ``None`` or empty on
                the first step.
            attention_mask: attention mask forwarded to the base model.
            labels: optional targets; a shifted cross-entropy loss is computed
                here on the prefill path.
            pixel_values, image_grid_thw, rope_deltas: forwarded to the base
                model on the prefill path only.
            **kwargs: forwarded unchanged to the base model.

        Returns:
            The base model's output object; on the prefill path it is rebuilt
            with ``loss`` replaced by the loss computed here (``None`` when no
            labels are given).

        Raises:
            RuntimeError: if there is no cached context and ``prefix_feats``
                is ``None``.
        """
        # Decode path: the cache already holds states computed with the
        # prefix, so nothing is injected and image inputs are not re-sent.
        if not self._cache_is_empty(past_key_values):
            return super().forward(
                input_ids=input_ids,
                past_key_values=past_key_values,
                attention_mask=attention_mask,
                pixel_values=None,
                image_grid_thw=None,
                rope_deltas=None,
                labels=labels,
                **kwargs,
            )

        if prefix_feats is None:
            raise RuntimeError("prefix_feats required when past_key_values is None.")

        # Cast the features to the projection's own dtype/device so the linear
        # layer works even if it is kept in a different precision than the
        # frozen backbone; the hook aligns the result with the embeddings.
        proj_weight = self.prefix_proj[0].weight
        delta = self.prefix_proj(
            prefix_feats.to(device=proj_weight.device, dtype=proj_weight.dtype)
        )

        # Inject the prefix for this single forward call (training or the
        # first generation step).
        with self._prefix_injection(delta):
            outputs = super().forward(
                input_ids=input_ids,
                # None, or the empty cache that this prefill pass should fill.
                past_key_values=past_key_values,
                attention_mask=attention_mask,
                labels=None,  # the loss is computed below
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
                rope_deltas=rope_deltas,
                **kwargs,
            )

        # Manual next-token cross-entropy: position t predicts token t + 1.
        loss = None
        if labels is not None:
            shift_logits = outputs.logits[..., :-1, :].contiguous()
            # Labels may arrive on a different device than the logits.
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        return outputs.__class__(
            loss=loss,
            **{k: v for k, v in outputs.items() if k != "loss"},
        )

    # ---------------- generation helper ----------------
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        prefix_feats=None,
        pixel_values=None,
        image_grid_thw=None,
        rope_deltas=None,
        **kwargs,
    ):
        """Assemble the keyword arguments for one generation step.

        On the first step (no cached context) the full prompt, the prefix
        features and the image inputs are passed. On later steps only the
        newest token, the cache and the attention mask are passed.
        ``cache_position`` and ``use_cache`` are forwarded when supplied.

        Returns:
            dict of keyword arguments for ``forward``.
        """
        if self._cache_is_empty(past_key_values):
            # First step: full context + prefix + image. The (possibly empty)
            # cache object is passed along so the prefill can populate it.
            model_inputs = {
                "input_ids": input_ids,
                "past_key_values": past_key_values,
                "attention_mask": attention_mask,
                "prefix_feats": prefix_feats,
                "pixel_values": pixel_values,
                "image_grid_thw": image_grid_thw,
                "rope_deltas": rope_deltas,
            }
        else:
            # Incremental step: just the new token.
            model_inputs = {
                "input_ids": input_ids[:, -1:],
                "past_key_values": past_key_values,
                "attention_mask": attention_mask,
            }

        # Keep generation bookkeeping that would otherwise be dropped here.
        for name in ("cache_position", "use_cache"):
            value = kwargs.get(name)
            if value is not None:
                model_inputs[name] = value
        return model_inputs


# --- register the new config/model pair with the Auto* helpers --------------
AutoModelForVision2Seq.register(
    Qwen2_5_VLLinearConfig,  # key: config class
    Qwen2_5_VLLinearForCausalLM,  # value: model class
)
AutoModelForCausalLM.register(
    Qwen2_5_VLLinearConfig,  # key: config class
    Qwen2_5_VLLinearForCausalLM,  # value: model class
)