jingyaogong committed
Commit c9d3f05 (verified) · Parent: d753ce0

Upload 3 files

Files changed (3)
  1. model_minimind.py +39 -15
  2. model_vlm.py +1 -0
  3. tokenizer_config.json +4 -5
model_minimind.py CHANGED
@@ -23,6 +23,7 @@ class MiniMindConfig(PretrainedConfig):
             vocab_size: int = 6400,
             rms_norm_eps: float = 1e-05,
             rope_theta: int = 1000000.0,
+            inference_rope_scaling: bool = False,
             flash_attn: bool = True,
             ####################################################
             # Here are the specific configurations of MOE
@@ -52,6 +53,15 @@ class MiniMindConfig(PretrainedConfig):
         self.vocab_size = vocab_size
         self.rms_norm_eps = rms_norm_eps
         self.rope_theta = rope_theta
+        self.inference_rope_scaling = inference_rope_scaling
+        # Extrapolated length = factor * original_max_position_embeddings
+        self.rope_scaling = {
+            "beta_fast": 4,
+            "beta_slow": 1,
+            "factor": 4,
+            "original_max_position_embeddings": 2048,
+            "type": "yarn"
+        } if self.inference_rope_scaling else None
         self.flash_attn = flash_attn
         ####################################################
         # Here are the specific configurations of MOE
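With these defaults, turning on inference_rope_scaling targets an inference-time context of factor × original_max_position_embeddings = 4 × 2048 = 8192 tokens, per the comment above; the scaling itself is applied in precompute_freqs_cis further down.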
@@ -73,10 +83,11 @@ class MiniMindConfig(PretrainedConfig):
 
 import math
 import torch
+import torch.nn.init as init
+import torch.nn.functional as F
 from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, Tuple, List, Union
-import torch.nn.functional as F
 from transformers import PreTrainedModel, GenerationMixin, PretrainedConfig
 from transformers.modeling_outputs import CausalLMOutputWithPast
 
@@ -94,8 +105,22 @@ class RMSNorm(torch.nn.Module):
         return self.weight * self._norm(x.float()).type_as(x)
 
 
-def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
-    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6,
+                         rope_scaling: Optional[dict] = None):
+    freqs = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+    if rope_scaling is not None:
+        orig_max, factor, beta_fast, beta_slow = (
+            rope_scaling.get("original_max_position_embeddings", 2048), rope_scaling.get("factor", 4),
+            rope_scaling.get("beta_fast", 4.0), rope_scaling.get("beta_slow", 1.0)
+        )
+        if end / orig_max > 1.0:
+            corr_dim = next((i for i in range(dim // 2) if 2 * math.pi / freqs[i] > orig_max), dim // 2)
+            power = torch.arange(0, dim // 2, device=freqs.device).float() / max(dim // 2 - 1, 1)
+            beta = beta_slow + (beta_fast - beta_slow) * power
+            # λ = (β·α - β + 1) / (β·α), the standard YaRN formula
+            scale = torch.where(torch.arange(dim // 2, device=freqs.device) < corr_dim, (beta * factor - beta + 1) / (beta * factor), 1.0 / factor)
+            freqs = freqs * scale
+
     t = torch.arange(end, device=freqs.device)
     freqs = torch.outer(t, freqs).float()
     freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1)
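Aside (not part of the commit): a standalone sketch that inlines the same YaRN scaling on toy sizes, to show what the new branch does to the per-dimension RoPE frequencies. dim is a made-up head size; rope_base and the YaRN constants mirror the defaults in this commit, and the printout is illustrative only.

import math
import torch

dim, rope_base = 64, 1e6  # toy head_dim; the model uses hidden_size // num_attention_heads
orig_max, factor, beta_fast, beta_slow = 2048, 4, 4.0, 1.0

freqs = 1.0 / (rope_base ** (torch.arange(0, dim, 2).float() / dim))
# first dimension whose wavelength 2*pi/freq exceeds the trained context window
corr_dim = next((i for i in range(dim // 2) if 2 * math.pi / freqs[i] > orig_max), dim // 2)
power = torch.arange(0, dim // 2).float() / max(dim // 2 - 1, 1)
beta = beta_slow + (beta_fast - beta_slow) * power
scale = torch.where(torch.arange(dim // 2) < corr_dim,
                    (beta * factor - beta + 1) / (beta * factor),  # YaRN ramp for short-wavelength dims
                    torch.tensor(1.0 / factor))                    # plain 1/factor interpolation for the rest
freqs = freqs * scale
print(scale[0].item(), scale[-1].item())  # ~1.0 for the fastest dim, 1/factor = 0.25 for the slowest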
@@ -118,9 +143,7 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
     if n_rep == 1:
         return x
     return (
-        x[:, :, :, None, :]
-        .expand(bs, slen, num_key_value_heads, n_rep, head_dim)
-        .reshape(bs, slen, num_key_value_heads * n_rep, head_dim)
+        x[:, :, :, None, :].expand(bs, slen, num_key_value_heads, n_rep, head_dim).reshape(bs, slen, num_key_value_heads * n_rep, head_dim)
     )
 
 
@@ -170,14 +193,14 @@ class Attention(nn.Module):
             repeat_kv(xv, self.n_rep).transpose(1, 2)
         )
 
-        if False and self.flash and seq_len != 1:
-            dropout_p = self.dropout if self.training else 0.0
-            attn_mask = None
-            if attention_mask is not None:
-                attn_mask = attention_mask.view(bsz, 1, 1, -1).expand(bsz, self.n_local_heads, seq_len, -1)
-                attn_mask = attn_mask.bool() if attention_mask is not None else None
+        if self.flash and seq_len > 1 and (attention_mask is None or torch.all(attention_mask == 1)):
+            attn_mask = (
+                None
+                if attention_mask is None
+                else attention_mask.view(bsz, 1, 1, -1).expand(bsz, self.n_local_heads, seq_len, -1).bool()
+            )
 
-            output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True)
+            output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
         else:
             scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
             scores = scores + torch.triu(
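Aside (my sketch, not part of the commit): with no padding mask, the re-enabled SDPA fast path and the manual triu-masked softmax fallback should produce the same output; the (bsz, n_heads, seq_len, head_dim) layout matches the tensors above.

import math
import torch
import torch.nn.functional as F

bsz, n_heads, seq_len, head_dim = 2, 4, 8, 16
xq, xk, xv = (torch.randn(bsz, n_heads, seq_len, head_dim) for _ in range(3))

# fast path: fused causal attention, no explicit mask
fast = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=0.0, is_causal=True)

# fallback path: explicit scores plus an upper-triangular -inf mask, as in the else branch
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(head_dim)
scores = scores + torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)
slow = torch.softmax(scores, dim=-1) @ xv

print(torch.allclose(fast, slow, atol=1e-5))  # True, up to float tolerance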
@@ -232,7 +255,6 @@ class MoEGate(nn.Module):
         self.reset_parameters()
 
     def reset_parameters(self) -> None:
-        import torch.nn.init as init
         init.kaiming_uniform_(self.weight, a=math.sqrt(5))
 
     def forward(self, hidden_states):
@@ -369,7 +391,8 @@ class MiniMindModel(nn.Module):
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
         freqs_cos, freqs_sin = precompute_freqs_cis(dim=config.hidden_size // config.num_attention_heads,
-                                                    end=config.max_position_embeddings, theta=config.rope_theta)
+                                                    end=config.max_position_embeddings, rope_base=config.rope_theta,
+                                                    rope_scaling=config.rope_scaling)
         self.register_buffer("freqs_cos", freqs_cos, persistent=False)
         self.register_buffer("freqs_sin", freqs_sin, persistent=False)
 
@@ -380,6 +403,7 @@ class MiniMindModel(nn.Module):
                 use_cache: bool = False,
                 **kwargs):
         batch_size, seq_length = input_ids.shape
+        if hasattr(past_key_values, 'layers'): past_key_values = None
         past_key_values = past_key_values or [None] * len(self.layers)
         start_pos = past_key_values[0][0].shape[1] if past_key_values[0] is not None else 0
 
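A note on the new guard (my reading; the commit message doesn't explain it): recent transformers releases pass the cache into forward() as a Cache object such as DynamicCache, which current versions expose with a .layers attribute, rather than the legacy list of (key, value) tuples this model manages itself. Resetting such objects to None lets the model rebuild its own list-of-tuples cache; the same guard is mirrored in model_vlm.py below. A minimal illustration, assuming a transformers version whose DynamicCache has .layers:

from transformers import DynamicCache

past_key_values = DynamicCache()
if hasattr(past_key_values, 'layers'):  # a transformers Cache object, not a legacy list of tuples
    past_key_values = None              # fall back to the model's own list-of-tuples cache
print(past_key_values)                  # None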
 
 
model_vlm.py CHANGED
@@ -119,6 +119,7 @@ class MiniMindVLM(MiniMindForCausalLM):
                 pixel_values: Optional[torch.FloatTensor] = None,
                 **args):
         batch_size, seq_length = input_ids.shape
+        if hasattr(past_key_values, 'layers'): past_key_values = None
         past_key_values = past_key_values or [None] * len(self.model.layers)
         start_pos = past_key_values[0][0].shape[1] if past_key_values[0] is not None else 0
 
 
 
tokenizer_config.json CHANGED
@@ -30,15 +30,14 @@
   },
   "additional_special_tokens": [],
   "bos_token": "<|im_start|>",
-  "chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{{ '<|im_start|>system\\n' + system_message + '<|im_end|>\\n' }}{% else %}{{ '<|im_start|>system\\nYou are a helpful assistant<|im_end|>\\n' }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\\n' + content + '<|im_end|>\\n<|im_start|>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\\n' }}{% endif %}{% endfor %}",
   "clean_up_tokenization_spaces": false,
   "eos_token": "<|im_end|>",
-  "extra_special_tokens": {},
   "legacy": true,
   "model_max_length": 32768,
   "pad_token": "<|endoftext|>",
   "sp_model_kwargs": {},
   "spaces_between_special_tokens": false,
-  "tokenizer_class": "PreTrainedTokenizer",
-  "unk_token": "<|endoftext|>"
-}
+  "tokenizer_class": "PreTrainedTokenizerFast",
+  "unk_token": "<|endoftext|>",
+  "chat_template": "{%- if tools %}\n  {{- '<|im_start|>system\\n' }}\n  {%- if messages[0].role == 'system' %}\n  {{- messages[0].content + '\\n\\n' }}\n  {%- endif %}\n  {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n  {%- for tool in tools %}\n  {{- \"\\n\" }}\n  {{- tool | tojson }}\n  {%- endfor %}\n  {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n  {%- if messages[0]['role'] == 'system' -%}\n  {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n  {%- else -%}\n  {{- '<|im_start|>system\\nYou are a helpful assistant<|im_end|>\\n' }}\n  {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n  {%- set index = (messages|length - 1) - loop.index0 %}\n  {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n  {%- set ns.multi_step_tool = false %}\n  {%- set ns.last_query_index = index %}\n  {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n  {%- if message.content is string %}\n  {%- set content = message.content %}\n  {%- else %}\n  {%- set content = '' %}\n  {%- endif %}\n  {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n  {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n  {%- elif message.role == \"assistant\" %}\n  {{- '<|im_start|>' + message.role + '\\n' + content }}\n  {%- if message.tool_calls %}\n  {%- for tool_call in message.tool_calls %}\n  {%- if (loop.first and content) or (not loop.first) %}\n  {{- '\\n' }}\n  {%- endif %}\n  {%- if tool_call.function %}\n  {%- set tool_call = tool_call.function %}\n  {%- endif %}\n  {{- '<tool_call>\\n{\"name\": \"' }}\n  {{- tool_call.name }}\n  {{- '\", \"arguments\": ' }}\n  {%- if tool_call.arguments is string %}\n  {{- tool_call.arguments }}\n  {%- else %}\n  {{- tool_call.arguments | tojson }}\n  {%- endif %}\n  {{- '}\\n</tool_call>' }}\n  {%- endfor %}\n  {%- endif %}\n  {{- '<|im_end|>\\n' }}\n  {%- elif message.role == \"tool\" %}\n  {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n  {{- '<|im_start|>user' }}\n  {%- endif %}\n  {{- '\\n<tool_response>\\n' }}\n  {{- content }}\n  {{- '\\n</tool_response>' }}\n  {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n  {{- '<|im_end|>\\n' }}\n  {%- endif %}\n  {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n  {{- '<|im_start|>assistant\\n' }}\n  {%- if enable_thinking is defined and enable_thinking is false %}\n  {{- '<think>\\n\\n</think>\\n\\n' }}\n  {%- endif %}\n{%- endif %}"
+}
 
 
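The replacement chat_template adds Qwen-style tool calling: a <tools> block of function signatures in the system prompt, <tool_call> emission for assistant tool calls, grouping of tool responses into <tool_response> blocks, and an optional empty <think> block when enable_thinking is false. A quick way to exercise the tools branch (my example; the local path and the get_weather schema are made up):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./MiniMind2")  # illustrative local path to these repo files

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",  # hypothetical tool, for demonstration only
        "description": "Look up the weather for a city",
        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
    },
}]
messages = [{"role": "user", "content": "What's the weather in Paris?"}]

# passing tools=... routes through the new {%- if tools %} branch of the template
prompt = tokenizer.apply_chat_template(messages, tools=tools, tokenize=False, add_generation_prompt=True)
print(prompt)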