from typing import Literal
import torch
from torch import nn
from transformers import (
AutoConfig,
AutoModel,
AutoModelForCausalLM,
AutoModelForMaskedLM,
DynamicCache,
)
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from .backbone_custom_modeling_qwen3 import CustomQwen3ForCausalLM
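
# BlockMask (used only for type hints below) ships with PyTorch's FlexAttention
# API, which older torch builds do not provide; fall back to None if it is absent.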
try:
from torch.nn.attention.flex_attention import BlockMask
except ImportError:
BlockMask = None
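

# Map the config string to the corresponding Hugging Face auto-model factory class.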
AUTO_MODEL_CLS = {
"AutoModel": AutoModel,
"AutoModelForCausalLM": AutoModelForCausalLM,
"AutoModelForMaskedLM": AutoModelForMaskedLM,
}
class AutoModelFromPreTrained(nn.Module):
"""Simple wrapper class that enables using AutoModel from pre-trained."""
def __init__(
self,
automodel_cls: Literal[
"AutoModel",
"AutoModelForCausalLM",
"AutoModelForMaskedLM",
],
pretrained_model_name_or_path: str,
trust_remote_code: bool = True,
num_layers: int = -1,
keep_top_layers: bool = False,
reinit_model: bool = False,
use_causal_mask: bool = False,
**automodel_init_kwargs,
):
super().__init__()
self.use_causal_mask = use_causal_mask
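        # Either build a freshly initialized custom Qwen3 backbone from its config,
        # or load pretrained weights through the selected Hugging Face auto class.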
if reinit_model:
auto_config = AutoConfig.from_pretrained(
pretrained_model_name_or_path,
num_hidden_layers=num_layers,
trust_remote_code=trust_remote_code,
**automodel_init_kwargs,
)
self.model = CustomQwen3ForCausalLM(auto_config)
# self.model = AUTO_MODEL_CLS[automodel_cls].from_config(auto_config)
else:
self.model = AUTO_MODEL_CLS[automodel_cls].from_pretrained(
pretrained_model_name_or_path,
trust_remote_code=trust_remote_code,
**automodel_init_kwargs,
)
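        # Optionally truncate the transformer stack: num_layers == -1 keeps every
        # layer; otherwise keep the top (last) or bottom (first) num_layers layers.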
num_layers = (
len(self.model.model.layers) if num_layers == -1 else num_layers
)
if keep_top_layers:
self.model.model.layers = self.model.model.layers[-num_layers:]
else:
self.model.model.layers = self.model.model.layers[:num_layers]
def forward(
self,
input_ids: torch.LongTensor,
attention_mask: torch.FloatTensor | BlockMask | None = None,
position_ids: torch.LongTensor | None = None,
cache_position: torch.LongTensor | None = None,
past_key_values: DynamicCache | None = None,
fix_cache_length: bool = False, # False for AR, True for diffusion models
        return_updated_cache: bool = False,
**kwargs,
) -> CausalLMOutputWithPast | BaseModelOutputWithPast:
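        # When fix_cache_length is set (diffusion-style decoding), record each
        # layer's current cache length so the cache can be trimmed back afterwards.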
prev_cache_len = None
if past_key_values is not None and fix_cache_length:
prev_cache_len = [
past_key_values[i][0].shape[-2] # type: ignore
for i in range(len(past_key_values))
]
if self.use_causal_mask:
attention_mask = None # None --> enforces use of causal mask
model_output = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
cache_position=cache_position,
past_key_values=past_key_values,
**kwargs,
)
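        # Optionally return only the refreshed KV cache, dropping logits and hidden states.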
if return_updated_cache:
return BaseModelOutputWithPast(past_key_values=model_output.past_key_values)
if (
prev_cache_len is not None
and model_output.get("past_key_values", None) is not None
):
# DynamicCache extends along sequence dimension by default;
# truncate back to original cache len
for i, cache_len in enumerate(prev_cache_len):
model_output.past_key_values.key_cache[i] = (
model_output.past_key_values.key_cache[i][..., :cache_len, :]
)
model_output.past_key_values.value_cache[i] = (
model_output.past_key_values.value_cache[i][..., :cache_len, :]
)
return model_output
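

# Illustrative usage sketch (not part of the module API): the checkpoint name is a
# placeholder, and the backbone is assumed to expose a ``model.layers`` list, as
# Qwen/Llama-style architectures do.
#
#     backbone = AutoModelFromPreTrained(
#         automodel_cls="AutoModelForCausalLM",
#         pretrained_model_name_or_path="<hf-checkpoint>",
#         num_layers=12,          # keep only the first 12 transformer blocks
#         use_causal_mask=True,   # ignore any externally supplied attention mask
#     )
#     logits = backbone(input_ids=torch.randint(0, 100, (1, 16))).logits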