from typing import Literal

import torch
from torch import nn
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForMaskedLM,
    DynamicCache,
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)

from .backbone_custom_modeling_qwen3 import CustomQwen3ForCausalLM

try:
    from torch.nn.attention.flex_attention import BlockMask
except ImportError:
    BlockMask = None

AUTO_MODEL_CLS = {
    "AutoModel": AutoModel,
    "AutoModelForCausalLM": AutoModelForCausalLM,
    "AutoModelForMaskedLM": AutoModelForMaskedLM,
}


class AutoModelFromPreTrained(nn.Module):
    """Simple wrapper class that enables using AutoModel from pre-trained."""

    def __init__(
        self,
        automodel_cls: Literal[
            "AutoModel",
            "AutoModelForCausalLM",
            "AutoModelForMaskedLM",
        ],
        pretrained_model_name_or_path: str,
        trust_remote_code: bool = True,
        num_layers: int = -1,  # -1 keeps all layers of the backbone
        keep_top_layers: bool = False,  # keep the last num_layers layers instead of the first
        reinit_model: bool = False,  # build from config with freshly initialized weights
        use_causal_mask: bool = False,  # ignore attention_mask and use the default causal mask
        **automodel_init_kwargs,
    ):
        super().__init__()
        self.use_causal_mask = use_causal_mask
        if reinit_model:
            # Build a randomly initialized model from the pretrained config,
            # overriding its depth with num_layers.
            auto_config = AutoConfig.from_pretrained(
                pretrained_model_name_or_path,
                num_hidden_layers=num_layers,
                trust_remote_code=trust_remote_code,
                **automodel_init_kwargs,
            )
            self.model = CustomQwen3ForCausalLM(auto_config)
            # self.model = AUTO_MODEL_CLS[automodel_cls].from_config(auto_config)
        else:
            self.model = AUTO_MODEL_CLS[automodel_cls].from_pretrained(
                pretrained_model_name_or_path,
                trust_remote_code=trust_remote_code,
                **automodel_init_kwargs,
            )
            num_layers = (
                len(self.model.model.layers) if num_layers == -1 else num_layers
            )
            # Truncate the backbone to num_layers transformer layers, keeping
            # either the top (last) or bottom (first) of the stack.
            if keep_top_layers:
                self.model.model.layers = self.model.model.layers[-num_layers:]
            else:
                self.model.model.layers = self.model.model.layers[:num_layers]

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.FloatTensor | BlockMask | None = None,
        position_ids: torch.LongTensor | None = None,
        cache_position: torch.LongTensor | None = None,
        past_key_values: DynamicCache | None = None,
        fix_cache_length: bool = False,  # False for autoregressive decoding; True for diffusion models (cache length must not grow)
        return_updated_cache: bool = False,
        **kwargs,
    ) -> CausalLMOutputWithPast | BaseModelOutputWithPast:
        # Record the per-layer cache length so it can be restored after the forward pass.
        prev_cache_len = None
        if past_key_values is not None and fix_cache_length:
            prev_cache_len = [
                past_key_values[i][0].shape[-2]  # type: ignore
                for i in range(len(past_key_values))
            ]
        if self.use_causal_mask:
            attention_mask = None  # None --> enforces use of causal mask
        model_output = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cache_position=cache_position,
            past_key_values=past_key_values,
            **kwargs,
        )
        if return_updated_cache:
            # Caller only needs the refreshed KV cache, not logits or hidden states.
            return BaseModelOutputWithPast(past_key_values=model_output.past_key_values)
        if (
            prev_cache_len is not None
            and model_output.get("past_key_values", None) is not None
        ):
            # DynamicCache extends along sequence dimension by default;
            # truncate back to original cache len
            for i, cache_len in enumerate(prev_cache_len):
                model_output.past_key_values.key_cache[i] = (
                    model_output.past_key_values.key_cache[i][..., :cache_len, :]
                )
                model_output.past_key_values.value_cache[i] = (
                    model_output.past_key_values.value_cache[i][..., :cache_len, :]
                )
        return model_output
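

# -----------------------------------------------------------------------------
# Usage sketch (illustrative only, not referenced by the rest of the repo):
# a minimal example of instantiating the wrapper and running a forward pass.
# The checkpoint name below is an assumption for demonstration; any causal-LM
# checkpoint whose backbone exposes `model.layers` should work the same way.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    checkpoint = "Qwen/Qwen3-0.6B"  # hypothetical choice; swap in your own checkpoint
    backbone = AutoModelFromPreTrained(
        automodel_cls="AutoModelForCausalLM",
        pretrained_model_name_or_path=checkpoint,
        num_layers=8,  # keep only the first 8 transformer layers
        use_causal_mask=True,  # ignore attention_mask; rely on the default causal mask
    )
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    inputs = tokenizer("Hello, world!", return_tensors="pt")
    out = backbone(
        input_ids=inputs["input_ids"],
        past_key_values=DynamicCache(),
        fix_cache_length=False,  # autoregressive: let the cache grow across calls
    )
    print(out.logits.shape)  # (batch_size, sequence_length, vocab_size)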