Upload 3 files
- model_minimind.py  +39 -15
- model_vlm.py  +1 -0
- tokenizer_config.json  +4 -5
model_minimind.py
CHANGED
@@ -23,6 +23,7 @@ class MiniMindConfig(PretrainedConfig):
             vocab_size: int = 6400,
             rms_norm_eps: float = 1e-05,
             rope_theta: int = 1000000.0,
+            inference_rope_scaling: bool = False,
             flash_attn: bool = True,
             ####################################################
             # Here are the specific configurations of MOE
@@ -52,6 +53,15 @@ class MiniMindConfig(PretrainedConfig):
         self.vocab_size = vocab_size
         self.rms_norm_eps = rms_norm_eps
         self.rope_theta = rope_theta
+        self.inference_rope_scaling = inference_rope_scaling
+        # Extrapolated length = factor * original_max_position_embeddings
+        self.rope_scaling = {
+            "beta_fast": 4,
+            "beta_slow": 1,
+            "factor": 4,
+            "original_max_position_embeddings": 2048,
+            "type": "yarn"
+        } if self.inference_rope_scaling else None
         self.flash_attn = flash_attn
         ####################################################
         # Here are the specific configurations of MOE
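With the hunk above, YaRN scaling is opt-in: the usable inference window becomes factor * original_max_position_embeddings = 4 * 2048 = 8192 positions. A minimal usage sketch (the import path is an assumption about where this file lives, not part of the commit):

# Hypothetical usage sketch: enable YaRN RoPE scaling at inference time.
from model_minimind import MiniMindConfig  # assumed import path

cfg = MiniMindConfig(inference_rope_scaling=True)
print(cfg.rope_scaling)
# {'beta_fast': 4, 'beta_slow': 1, 'factor': 4,
#  'original_max_position_embeddings': 2048, 'type': 'yarn'}
print(cfg.rope_scaling["factor"] * cfg.rope_scaling["original_max_position_embeddings"])  # 8192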
@@ -73,10 +83,11 @@ class MiniMindConfig(PretrainedConfig):
 
 import math
 import torch
+import torch.nn.init as init
+import torch.nn.functional as F
 from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, Tuple, List, Union
-import torch.nn.functional as F
 from transformers import PreTrainedModel, GenerationMixin, PretrainedConfig
 from transformers.modeling_outputs import CausalLMOutputWithPast
 
@@ -94,8 +105,22 @@ class RMSNorm(torch.nn.Module):
         return self.weight * self._norm(x.float()).type_as(x)
 
 
-def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
-    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6,
+                         rope_scaling: Optional[dict] = None):
+    freqs = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+    if rope_scaling is not None:
+        orig_max, factor, beta_fast, beta_slow = (
+            rope_scaling.get("original_max_position_embeddings", 2048), rope_scaling.get("factor", 4),
+            rope_scaling.get("beta_fast", 4.0), rope_scaling.get("beta_slow", 1.0)
+        )
+        if end / orig_max > 1.0:
+            corr_dim = next((i for i in range(dim // 2) if 2 * math.pi / freqs[i] > orig_max), dim // 2)
+            power = torch.arange(0, dim // 2, device=freqs.device).float() / max(dim // 2 - 1, 1)
+            beta = beta_slow + (beta_fast - beta_slow) * power
+            # λ = (β·α - β + 1) / (β·α), the standard YaRN formula
+            scale = torch.where(torch.arange(dim // 2, device=freqs.device) < corr_dim, (beta * factor - beta + 1) / (beta * factor), 1.0 / factor)
+            freqs = freqs * scale
+
     t = torch.arange(end, device=freqs.device)
     freqs = torch.outer(t, freqs).float()
     freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1)
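In the new YaRN branch, `corr_dim` is the first rotary dimension whose wavelength 2π/freq exceeds the original training window; dimensions from `corr_dim` onward are fully interpolated by 1/factor, while faster dimensions get the blended scale λ = (β·α - β + 1)/(β·α), where α is the factor and β ramps per dimension between beta_slow and beta_fast. A quick sanity sketch of calling it (the import path is an assumption):

import torch
from model_minimind import precompute_freqs_cis  # assumed import path

scaling = {"beta_fast": 4, "beta_slow": 1, "factor": 4,
           "original_max_position_embeddings": 2048, "type": "yarn"}
cos_plain, sin_plain = precompute_freqs_cis(dim=64, end=8192, rope_base=1e6)
cos_yarn, sin_yarn = precompute_freqs_cis(dim=64, end=8192, rope_base=1e6, rope_scaling=scaling)
print(cos_plain.shape, cos_yarn.shape)   # torch.Size([8192, 64]) for both
print(torch.equal(cos_plain, cos_yarn))  # False: frequencies were rescaled for extrapolation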
@@ -118,9 +143,7 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
     if n_rep == 1:
         return x
     return (
-        x[:, :, :, None, :]
-        .expand(bs, slen, num_key_value_heads, n_rep, head_dim)
-        .reshape(bs, slen, num_key_value_heads * n_rep, head_dim)
+        x[:, :, :, None, :].expand(bs, slen, num_key_value_heads, n_rep, head_dim).reshape(bs, slen, num_key_value_heads * n_rep, head_dim)
     )
 
 
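This hunk only joins the expand/reshape chain onto a single line; behavior is unchanged. For reference, a standalone shape check of the same grouped-query-attention KV repetition (example dimensions are arbitrary):

import torch

# Repeat each KV head n_rep times so K/V match the number of query heads.
bs, slen, num_key_value_heads, head_dim, n_rep = 2, 16, 2, 64, 4
x = torch.randn(bs, slen, num_key_value_heads, head_dim)
y = (
    x[:, :, :, None, :]
    .expand(bs, slen, num_key_value_heads, n_rep, head_dim)
    .reshape(bs, slen, num_key_value_heads * n_rep, head_dim)
)
print(y.shape)  # torch.Size([2, 16, 8, 64]): each KV head repeated n_rep times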
@@ -170,14 +193,14 @@ class Attention(nn.Module):
             repeat_kv(xv, self.n_rep).transpose(1, 2)
         )
 
-        if self.flash and seq_len != 1:
-            dropout_p = self.dropout if self.training else 0.0
-            attn_mask = None
-            if attention_mask is not None:
-                attn_mask = attention_mask.view(bsz, 1, 1, -1).expand(bsz, self.n_local_heads, seq_len, -1)
-                attn_mask = attn_mask.bool() if attention_mask is not None else None
+        if self.flash and seq_len > 1 and (attention_mask is None or torch.all(attention_mask == 1)):
+            attn_mask = (
+                None
+                if attention_mask is None
+                else attention_mask.view(bsz, 1, 1, -1).expand(bsz, self.n_local_heads, seq_len, -1).bool()
+            )
 
-            output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True)
+            output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
         else:
             scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
             scores = scores + torch.triu(
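The rewritten guard takes the SDPA fast path only when no real padding exists (mask absent or all ones), falling back to the explicit score computation otherwise. A standalone sketch showing the two branches agree in the unpadded causal case (illustrative shapes, not the module's own code):

import math
import torch
import torch.nn.functional as F

bsz, n_heads, seq_len, head_dim = 2, 8, 16, 64
xq, xk, xv = (torch.randn(bsz, n_heads, seq_len, head_dim) for _ in range(3))

# Fast path: fused kernel with an implicit causal mask.
flash_out = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=0.0, is_causal=True)

# Manual path: explicit scores plus an upper-triangular -inf mask.
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(head_dim)
scores = scores + torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
manual_out = F.softmax(scores, dim=-1) @ xv

print(torch.allclose(flash_out, manual_out, atol=1e-5))  # True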
@@ -232,7 +255,6 @@ class MoEGate(nn.Module):
         self.reset_parameters()
 
     def reset_parameters(self) -> None:
-        import torch.nn.init as init
         init.kaiming_uniform_(self.weight, a=math.sqrt(5))
 
     def forward(self, hidden_states):
@@ -369,7 +391,8 @@ class MiniMindModel(nn.Module):
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
         freqs_cos, freqs_sin = precompute_freqs_cis(dim=config.hidden_size // config.num_attention_heads,
-                                                    end=config.max_position_embeddings, theta=config.rope_theta)
+                                                    end=config.max_position_embeddings, rope_base=config.rope_theta,
+                                                    rope_scaling=config.rope_scaling)
         self.register_buffer("freqs_cos", freqs_cos, persistent=False)
         self.register_buffer("freqs_sin", freqs_sin, persistent=False)
 
@@ -380,6 +403,7 @@ class MiniMindModel(nn.Module):
                 use_cache: bool = False,
                 **kwargs):
         batch_size, seq_length = input_ids.shape
+        if hasattr(past_key_values, 'layers'): past_key_values = None
         past_key_values = past_key_values or [None] * len(self.layers)
         start_pos = past_key_values[0][0].shape[1] if past_key_values[0] is not None else 0
 
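The new `hasattr(past_key_values, 'layers')` check is a compatibility shim: recent transformers releases hand `generate()`'s cache in as a `Cache` object (e.g. `DynamicCache`), which exposes a `.layers` attribute, while this model expects the legacy list of per-layer (key, value) tuples, so the object is discarded and rebuilt as a list. A small sketch of the distinction (assumes a transformers version whose cache classes carry `.layers`):

from transformers import DynamicCache

pkv = DynamicCache()
print(hasattr(pkv, 'layers'))           # True on recent versions -> the guard resets it to None
print(hasattr([None, None], 'layers'))  # False: a legacy list cache passes through untouched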
model_vlm.py
CHANGED
@@ -119,6 +119,7 @@ class MiniMindVLM(MiniMindForCausalLM):
                 pixel_values: Optional[torch.FloatTensor] = None,
                 **args):
         batch_size, seq_length = input_ids.shape
+        if hasattr(past_key_values, 'layers'): past_key_values = None
         past_key_values = past_key_values or [None] * len(self.model.layers)
         start_pos = past_key_values[0][0].shape[1] if past_key_values[0] is not None else 0
 
tokenizer_config.json
CHANGED
@@ -30,15 +30,14 @@
   },
   "additional_special_tokens": [],
   "bos_token": "<|im_start|>",
-  "chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<|im_start|>system\\n' + system_message + '<|im_end|>\\n' }}{% else %}{{ '<|im_start|>system\\nYou are a helpful assistant<|im_end|>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\\n' + content + '<|im_end|>\\n<|im_start|>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\\n' }}{% endif %}{% endfor %}",
   "clean_up_tokenization_spaces": false,
   "eos_token": "<|im_end|>",
-  "extra_special_tokens": {},
   "legacy": true,
   "model_max_length": 32768,
   "pad_token": "<|endoftext|>",
   "sp_model_kwargs": {},
   "spaces_between_special_tokens": false,
-  "tokenizer_class": "PreTrainedTokenizer",
-  "unk_token": "<|endoftext|>"
-}
+  "tokenizer_class": "PreTrainedTokenizerFast",
+  "unk_token": "<|endoftext|>",
+  "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' -%}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else -%}\n {{- '<|im_start|>system\\nYou are a helpful assistant<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}"
+}
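The relocated `chat_template` is substantially extended: it now emits Qwen-style tool-calling sections (`<tools>`, `<tool_call>`, `<tool_response>`) and honors an `enable_thinking` switch for the generation prompt. A usage sketch with the standard `apply_chat_template` API (the checkpoint path is an illustrative placeholder):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("./MiniMind2")  # placeholder path for these files
messages = [{"role": "user", "content": "What is RoPE?"}]
text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(text)
# <|im_start|>system
# You are a helpful assistant<|im_end|>
# <|im_start|>user
# What is RoPE?<|im_end|>
# <|im_start|>assistant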