liaojc committed on
Commit a814953 · verified · 1 Parent(s): e7f13ee

Add files using upload-large-folder tool
LICENSE ADDED
@@ -0,0 +1,202 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
README.md ADDED
@@ -0,0 +1,261 @@
+ # ERNIE-4.5-300B-A47B
+
+ ## Features
+ Starting from the text-specific parameters of ERNIE-4.5-Base, we conduct post-training in two phases: supervised fine-tuning (SFT) and reinforcement learning (RL). Using a unified reward system that accommodates both reasoning and general tasks, we apply progressive RL training with the Unified Preference Optimization objective. The resulting model achieves superior performance across a broad spectrum of tasks, including creative writing, mathematical reasoning, and code generation. We summarize its features as follows:
+
+ * **Superior performance on instruction-following tasks**: Demonstrates an exceptional capability to understand, prioritize, and execute complex user instructions, achieving top results on both internal human-evaluation datasets and public instruction-following benchmarks.
+ * **Factual reliability & knowledge accuracy**: Excels on the SimpleQA and CN-SimpleQA benchmarks, dramatically reducing hallucinations and ensuring responses are grounded in verifiable facts.
+ * **Expressive, persuasive responses**: Generates engaging, well-structured responses with personality, clear rationale, and persuasive argumentation, making it ideal for interactive and content-creation scenarios.
+ * **Advanced math & coding proficiency**: Delivers precise, step-by-step mathematical reasoning and generates clean, efficient code across multiple programming languages, supporting both algorithmic problem solving and real-world development.
+
+ ## Model Overview
+ ERNIE-4.5-300B-A47B is a post-trained text MoE model with 300B total parameters and 47B parameters activated per token. The model configuration details are as follows:
+
+ |Key|Value|
+ |-|-|
+ |Modality|Text|
+ |Training Stage|Posttraining|
+ |Params (Total / Activated)|300B / 47B|
+ |Layers|54|
+ |Heads (Q/KV)|64 / 8|
+ |Text Experts (Total / Activated)|64 / 8|
+ |Vision Experts (Total / Activated)|64 / 8|
+ |Context Length|131072|
+
+ ## Quickstart
+ ### Model Finetuning with ERNIEKit
+
+ [ERNIEKit](https://github.com/PaddlePaddle/ERNIE) is a PaddlePaddle-based training toolkit designed specifically for the ERNIE series of open-source large models. It provides comprehensive support for scenarios such as instruction fine-tuning (SFT, LoRA) and alignment training (DPO), ensuring optimal performance.
+
+ Usage examples:
+ ```bash
+ # SFT
+ erniekit train --stage SFT --model_name_or_path baidu/ERNIE-4.5-300B-A47B-Paddle --train_dataset_path your_dataset_path
+ # DPO
+ erniekit train --stage DPO --model_name_or_path baidu/ERNIE-4.5-300B-A47B-Paddle --train_dataset_path your_dataset_path
+ ```
+ For more detailed examples, including SFT with LoRA, multi-GPU configurations, and advanced scripts, please refer to the examples folder in the [ERNIEKit](https://github.com/PaddlePaddle/ERNIE) repository.
+
+ ### Using `transformers` library
+
+ The following code snippet illustrates how to use the model to generate content from a given input.
+ ```python
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ model_name = "baidu/ERNIE-4.5-300B-A47B-PT"
+
+ # load the tokenizer and the model
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
+
+ # prepare the model input
+ prompt = "Give me a short introduction to large language models."
+ messages = [
+     {"role": "user", "content": prompt}
+ ]
+ text = tokenizer.apply_chat_template(
+     messages,
+     tokenize=False,
+     add_generation_prompt=True
+ )
+ model_inputs = tokenizer([text], add_special_tokens=False, return_tensors="pt").to(model.device)
+
+ # conduct text completion
+ generated_ids = model.generate(
+     model_inputs.input_ids,
+     max_new_tokens=1024
+ )
+ output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
+
+ # decode the generated ids
+ generate_text = tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")
+ print("generate_text:", generate_text)
+ ```
+
+ ### Using FastDeploy
+
+ Service deployment can be completed quickly using FastDeploy with the commands below. For more detailed usage instructions, please refer to the [FastDeploy Repository](https://github.com/PaddlePaddle/FastDeploy).
+
+ **Note**: To deploy on a configuration of 4 GPUs, each with at least 80G of memory, specify `--quantization wint4`. If you specify `--quantization wint8`, 8 GPUs are required.
+
+ ```bash
+ # --max-model-len: maximum supported number of tokens
+ # --max-num-seqs: maximum concurrent processing capacity
+ python -m fastdeploy.entrypoints.openai.api_server \
+        --model baidu/ERNIE-4.5-300B-A47B-Paddle \
+        --port 8180 \
+        --metrics-port 8181 \
+        --quantization wint4 \
+        --tensor-parallel-size 8 \
+        --engine-worker-queue-port 8182 \
+        --max-model-len 32768 \
+        --max-num-seqs 32
+ ```
+
+ To deploy the W4A8C8 quantized version using FastDeploy, run the following command.
+
+ ```bash
+ # --max-model-len: maximum supported number of tokens
+ # --max-num-seqs: maximum concurrent processing capacity
+ python -m fastdeploy.entrypoints.openai.api_server \
+        --model baidu/ERNIE-4.5-300B-A47B-W4A8C8-TP4 \
+        --port 8180 \
+        --metrics-port 8181 \
+        --engine-worker-queue-port 8182 \
+        --tensor-parallel-size 4 \
+        --max-model-len 32768 \
+        --max-num-seqs 32
+ ```
+
+ To deploy the WINT2 quantized version using FastDeploy on a single 141G GPU, run the following command.
+
+ ```bash
+ # --max-model-len: maximum supported number of tokens
+ # --max-num-seqs: maximum concurrent processing capacity
+ python -m fastdeploy.entrypoints.openai.api_server \
+        --model "baidu/ERNIE-4.5-300B-A47B-2BITS-Paddle" \
+        --port 8180 \
+        --metrics-port 8181 \
+        --engine-worker-queue-port 8182 \
+        --tensor-parallel-size 1 \
+        --max-model-len 32768 \
+        --max-num-seqs 128
+ ```
+
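+ Once one of the servers above is running, it can be queried through its OpenAI-compatible endpoint. A minimal sketch (assumes the server is listening on port 8180 as configured above, and uses the separate `openai` client package, which is not part of FastDeploy):
+
+ ```python
+ from openai import OpenAI
+
+ # Point the client at the local FastDeploy server; the API key is unused locally.
+ client = OpenAI(base_url="http://localhost:8180/v1", api_key="none")
+
+ response = client.chat.completions.create(
+     model="baidu/ERNIE-4.5-300B-A47B-Paddle",
+     messages=[{"role": "user", "content": "Give me a short introduction to large language models."}],
+     temperature=0.8,
+     top_p=0.8,
+ )
+ print(response.choices[0].message.content)
+ ```
+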
+ The following code snippet illustrates how to use ERNIE-4.5-300B-A47B-FP8 to generate content from given inputs.
+
+ ```python
+ from fastdeploy import LLM, SamplingParams
+
+ prompts = [
+     "Hello, my name is",
+ ]
+
+ sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=128)
+
+ model = "baidu/ERNIE-4.5-300B-A47B-FP8"
+ llm = LLM(model=model, tensor_parallel_size=8, max_model_len=8192, num_gpu_blocks_override=1024, engine_worker_queue_port=9981)
+
+ outputs = llm.generate(prompts, sampling_params)
+
+ for output in outputs:
+     prompt = output.prompt
+     generated_text = output.outputs.text
+     print("generated_text", generated_text)
+ ```
+
+ ### Using vLLM
+ vLLM support is currently being adapted; for now, we recommend using our fork repository [vllm](https://github.com/CSWYF3634076/vllm/tree/ernie).
+ ```bash
+ # 80G * 16 GPU
+ vllm serve baidu/ERNIE-4.5-300B-A47B-PT --trust-remote-code
+ ```
+ ```bash
+ # FP8 online quantization, 80G * 8 GPU
+ vllm serve baidu/ERNIE-4.5-300B-A47B-PT --trust-remote-code --quantization fp8
+ ```
+
+ ## Best Practices
+ ### Sampling Parameters
+ To achieve optimal performance, we suggest using `temperature=0.8` and `top_p=0.8`.
+
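+ As a brief sketch of applying these settings with the `transformers` snippet from the Quickstart above (assumes `model` and `model_inputs` are set up as shown there):
+
+ ```python
+ # Enable sampling with the recommended decoding settings instead of greedy search.
+ generated_ids = model.generate(
+     model_inputs.input_ids,
+     max_new_tokens=1024,
+     do_sample=True,
+     temperature=0.8,
+     top_p=0.8,
+ )
+ ```
+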
+ ### Prompts for Web Search
+ For web search, the prompt templates below take {references}, {date}, and {question} as arguments.
+
+ For Chinese questions, we use the following prompt:
+
+ ```
+ ernie_search_zh_prompt = \
+ '''下面你会收到当前时间、多个不同来源的参考文章和一段对话。你的任务是阅读多个参考文章,并根据参考文章中的信息回答对话中的问题。
+ 以下是当前时间和参考文章:
+ ---------
+ #当前时间
+ {date}
+
+ #参考文章
+ {references}
+
+ ---------
+ 请注意:
+ 1. 回答必须结合问题需求和当前时间,对参考文章的可用性进行判断,避免在回答中使用错误或过时的信息。
+ 2. 当参考文章中的信息无法准确地回答问题时,你需要在回答中提供获取相应信息的建议,或承认无法提供相应信息。
+ 3. 你需要优先根据百科、官网、权威机构、专业网站等高权威性来源的信息来回答问题。
+ 4. 回复需要综合参考文章中的相关数字、案例、法律条文、公式等信息,使你的答案更专业。
+ 5. 当问题属于创作类任务时,需注意以下维度:
+ - 态度鲜明:观点、立场清晰明确,避免模棱两可,语言果断直接
+ - 文采飞扬:用词精准生动,善用修辞手法,增强感染力
+ - 有理有据:逻辑严密递进,结合权威数据/事实支撑论点
+ ---------
+ 下面请结合以上信息,回答问题,补全对话
+ {question}'''
+ ```
+ For English questions, we use the following prompt:
+
+ ```
+ ernie_search_en_prompt = \
+ '''
+ Below you will be given the current time, multiple references from different sources, and a conversation. Your task is to read the references and use the information in them to answer the question in the conversation.
+ Here are the current time and the references:
+ ---------
+ #Current Time
+ {date}
+
+ #References
+ {references}
+
+ ---------
+ Please note:
+ 1. Based on the question’s requirements and the current time, assess the usefulness of the references to avoid using inaccurate or outdated information in the answer.
+ 2. If the references do not provide enough information to accurately answer the question, you should suggest how to obtain the relevant information or acknowledge that you are unable to provide it.
+ 3. Prioritize using information from highly authoritative sources such as encyclopedias, official websites, authoritative institutions, and professional websites when answering questions.
+ 4. Incorporate relevant numbers, cases, legal provisions, formulas, and other details from the references to make your answer more professional.
+ 5. For creative tasks, keep these dimensions in mind:
+ - Clear attitude: Clear views and positions, avoid ambiguity, and use decisive and direct language
+ - Brilliant writing: Precise and vivid words, good use of rhetoric, and enhance the appeal
+ - Well-reasoned: Rigorous logic and progressive, combined with authoritative data/facts to support the argument
+
+ ---------
+ Now, using the information above, answer the question and complete the conversation:
+ {question}'''
+ ```
+
+ Parameter notes:
+
+ * {question} is the user’s question.
+ * {date} is the current time; the recommended format is “YYYY-MM-DD HH:MM:SS, Day of the Week, Beijing/China.”
+ * {references} contains the reference articles; the recommended format is shown below:
+
+ ```
+ ##参考文章1
+ 标题:周杰伦
+ 文章发布时间:2025-04-20
+ 内容:周杰伦(Jay Chou),1979年1月18日出生于台湾省新北市,祖籍福建省永春县,华语流行乐男歌手、音乐人、演员、导演、编剧,毕业于淡江中学。2000年,发行个人首张音乐专辑《Jay》。...
+ 来源网站网址:baike.baidu.com
+ 来源网站的网站名:百度百科
+
+ ##参考文章2
+ ...
+ ```
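+ To assemble the final prompt, the placeholders can be filled with standard Python string formatting. A minimal sketch (the date, references, and question values below are illustrative only):
+
+ ```python
+ # Fill the web-search template defined above with concrete values.
+ filled_prompt = ernie_search_en_prompt.format(
+     date="2025-06-30 10:00:00, Monday, Beijing/China",
+     references="##Reference 1\nTitle: ...\nContent: ...",
+     question="User: What is the weather like today?",
+ )
+ # The filled prompt is then sent to the model as the user message.
+ ```
+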
+
+ ## License
+
+ The ERNIE 4.5 models are provided under the Apache License 2.0. This license permits commercial use, subject to its terms and conditions. Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
+
+ ## Citation
+
+ If you find ERNIE 4.5 useful or wish to use it in your projects, please kindly cite our technical report:
+
+ ```bibtex
+ @misc{ernie2025technicalreport,
+       title={ERNIE 4.5 Technical Report},
+       author={Baidu ERNIE Team},
+       year={2025},
+       eprint={},
+       archivePrefix={arXiv},
+       primaryClass={cs.CL},
+       url={}
+ }
+ ```
added_tokens.json ADDED
@@ -0,0 +1 @@
+ {"<|IMAGE_PLACEHOLDER|>": 100295, "<|AUDIO_PLACEHOLDER|>": 100296, "<|LOC_0|>": 100297, "<|LOC_1|>": 100298, "<|LOC_2|>": 100299, "<|LOC_3|>": 100300, "<|LOC_4|>": 100301, "<|LOC_5|>": 100302, "<|LOC_6|>": 100303, "<|LOC_7|>": 100304, "<|LOC_8|>": 100305, "<|LOC_9|>": 100306, "<|LOC_10|>": 100307, "<|LOC_11|>": 100308, "<|LOC_12|>": 100309, "<|LOC_13|>": 100310, "<|LOC_14|>": 100311, "<|LOC_15|>": 100312, "<|LOC_16|>": 100313, "<|LOC_17|>": 100314, "<|LOC_18|>": 100315, "<|LOC_19|>": 100316, "<|LOC_20|>": 100317, "<|LOC_21|>": 100318, "<|LOC_22|>": 100319, "<|LOC_23|>": 100320, "<|LOC_24|>": 100321, "<|LOC_25|>": 100322, "<|LOC_26|>": 100323, "<|LOC_27|>": 100324, "<|LOC_28|>": 100325, "<|LOC_29|>": 100326, "<|LOC_30|>": 100327, "<|LOC_31|>": 100328, "<|LOC_32|>": 100329, "<|LOC_33|>": 100330, "<|LOC_34|>": 100331, "<|LOC_35|>": 100332, "<|LOC_36|>": 100333, "<|LOC_37|>": 100334, "<|LOC_38|>": 100335, "<|LOC_39|>": 100336, "<|LOC_40|>": 100337, "<|LOC_41|>": 100338, "<|LOC_42|>": 100339, "<|LOC_43|>": 100340, "<|LOC_44|>": 100341, "<|LOC_45|>": 100342, "<|LOC_46|>": 100343, "<|LOC_47|>": 100344, "<|LOC_48|>": 100345, "<|LOC_49|>": 100346, "<|LOC_50|>": 100347, "<|LOC_51|>": 100348, "<|LOC_52|>": 100349, "<|LOC_53|>": 100350, "<|LOC_54|>": 100351, "<|LOC_55|>": 100352, "<|LOC_56|>": 100353, "<|LOC_57|>": 100354, "<|LOC_58|>": 100355, "<|LOC_59|>": 100356, "<|LOC_60|>": 100357, "<|LOC_61|>": 100358, "<|LOC_62|>": 100359, "<|LOC_63|>": 100360, "<|LOC_64|>": 100361, "<|LOC_65|>": 100362, "<|LOC_66|>": 100363, "<|LOC_67|>": 100364, "<|LOC_68|>": 100365, "<|LOC_69|>": 100366, "<|LOC_70|>": 100367, "<|LOC_71|>": 100368, "<|LOC_72|>": 100369, "<|LOC_73|>": 100370, "<|LOC_74|>": 100371, "<|LOC_75|>": 100372, "<|LOC_76|>": 100373, "<|LOC_77|>": 100374, "<|LOC_78|>": 100375, "<|LOC_79|>": 100376, "<|LOC_80|>": 100377, "<|LOC_81|>": 100378, "<|LOC_82|>": 100379, "<|LOC_83|>": 100380, "<|LOC_84|>": 100381, "<|LOC_85|>": 100382, "<|LOC_86|>": 100383, "<|LOC_87|>": 100384, "<|LOC_88|>": 100385, "<|LOC_89|>": 100386, "<|LOC_90|>": 100387, "<|LOC_91|>": 100388, "<|LOC_92|>": 100389, "<|LOC_93|>": 100390, "<|LOC_94|>": 100391, "<|LOC_95|>": 100392, "<|LOC_96|>": 100393, "<|LOC_97|>": 100394, "<|LOC_98|>": 100395, "<|LOC_99|>": 100396, "<|LOC_100|>": 100397, "<|LOC_101|>": 100398, "<|LOC_102|>": 100399, "<|LOC_103|>": 100400, "<|LOC_104|>": 100401, "<|LOC_105|>": 100402, "<|LOC_106|>": 100403, "<|LOC_107|>": 100404, "<|LOC_108|>": 100405, "<|LOC_109|>": 100406, "<|LOC_110|>": 100407, "<|LOC_111|>": 100408, "<|LOC_112|>": 100409, "<|LOC_113|>": 100410, "<|LOC_114|>": 100411, "<|LOC_115|>": 100412, "<|LOC_116|>": 100413, "<|LOC_117|>": 100414, "<|LOC_118|>": 100415, "<|LOC_119|>": 100416, "<|LOC_120|>": 100417, "<|LOC_121|>": 100418, "<|LOC_122|>": 100419, "<|LOC_123|>": 100420, "<|LOC_124|>": 100421, "<|LOC_125|>": 100422, "<|LOC_126|>": 100423, "<|LOC_127|>": 100424, "<|LOC_128|>": 100425, "<|LOC_129|>": 100426, "<|LOC_130|>": 100427, "<|LOC_131|>": 100428, "<|LOC_132|>": 100429, "<|LOC_133|>": 100430, "<|LOC_134|>": 100431, "<|LOC_135|>": 100432, "<|LOC_136|>": 100433, "<|LOC_137|>": 100434, "<|LOC_138|>": 100435, "<|LOC_139|>": 100436, "<|LOC_140|>": 100437, "<|LOC_141|>": 100438, "<|LOC_142|>": 100439, "<|LOC_143|>": 100440, "<|LOC_144|>": 100441, "<|LOC_145|>": 100442, "<|LOC_146|>": 100443, "<|LOC_147|>": 100444, "<|LOC_148|>": 100445, "<|LOC_149|>": 100446, "<|LOC_150|>": 100447, "<|LOC_151|>": 100448, "<|LOC_152|>": 100449, "<|LOC_153|>": 100450, "<|LOC_154|>": 100451, "<|LOC_155|>": 100452, 
"<|LOC_156|>": 100453, "<|LOC_157|>": 100454, "<|LOC_158|>": 100455, "<|LOC_159|>": 100456, "<|LOC_160|>": 100457, "<|LOC_161|>": 100458, "<|LOC_162|>": 100459, "<|LOC_163|>": 100460, "<|LOC_164|>": 100461, "<|LOC_165|>": 100462, "<|LOC_166|>": 100463, "<|LOC_167|>": 100464, "<|LOC_168|>": 100465, "<|LOC_169|>": 100466, "<|LOC_170|>": 100467, "<|LOC_171|>": 100468, "<|LOC_172|>": 100469, "<|LOC_173|>": 100470, "<|LOC_174|>": 100471, "<|LOC_175|>": 100472, "<|LOC_176|>": 100473, "<|LOC_177|>": 100474, "<|LOC_178|>": 100475, "<|LOC_179|>": 100476, "<|LOC_180|>": 100477, "<|LOC_181|>": 100478, "<|LOC_182|>": 100479, "<|LOC_183|>": 100480, "<|LOC_184|>": 100481, "<|LOC_185|>": 100482, "<|LOC_186|>": 100483, "<|LOC_187|>": 100484, "<|LOC_188|>": 100485, "<|LOC_189|>": 100486, "<|LOC_190|>": 100487, "<|LOC_191|>": 100488, "<|LOC_192|>": 100489, "<|LOC_193|>": 100490, "<|LOC_194|>": 100491, "<|LOC_195|>": 100492, "<|LOC_196|>": 100493, "<|LOC_197|>": 100494, "<|LOC_198|>": 100495, "<|LOC_199|>": 100496, "<|LOC_200|>": 100497, "<|LOC_201|>": 100498, "<|LOC_202|>": 100499, "<|LOC_203|>": 100500, "<|LOC_204|>": 100501, "<|LOC_205|>": 100502, "<|LOC_206|>": 100503, "<|LOC_207|>": 100504, "<|LOC_208|>": 100505, "<|LOC_209|>": 100506, "<|LOC_210|>": 100507, "<|LOC_211|>": 100508, "<|LOC_212|>": 100509, "<|LOC_213|>": 100510, "<|LOC_214|>": 100511, "<|LOC_215|>": 100512, "<|LOC_216|>": 100513, "<|LOC_217|>": 100514, "<|LOC_218|>": 100515, "<|LOC_219|>": 100516, "<|LOC_220|>": 100517, "<|LOC_221|>": 100518, "<|LOC_222|>": 100519, "<|LOC_223|>": 100520, "<|LOC_224|>": 100521, "<|LOC_225|>": 100522, "<|LOC_226|>": 100523, "<|LOC_227|>": 100524, "<|LOC_228|>": 100525, "<|LOC_229|>": 100526, "<|LOC_230|>": 100527, "<|LOC_231|>": 100528, "<|LOC_232|>": 100529, "<|LOC_233|>": 100530, "<|LOC_234|>": 100531, "<|LOC_235|>": 100532, "<|LOC_236|>": 100533, "<|LOC_237|>": 100534, "<|LOC_238|>": 100535, "<|LOC_239|>": 100536, "<|LOC_240|>": 100537, "<|LOC_241|>": 100538, "<|LOC_242|>": 100539, "<|LOC_243|>": 100540, "<|LOC_244|>": 100541, "<|LOC_245|>": 100542, "<|LOC_246|>": 100543, "<|LOC_247|>": 100544, "<|LOC_248|>": 100545, "<|LOC_249|>": 100546, "<|LOC_250|>": 100547, "<|LOC_251|>": 100548, "<|LOC_252|>": 100549, "<|LOC_253|>": 100550, "<|LOC_254|>": 100551, "<|LOC_255|>": 100552, "<|LOC_256|>": 100553, "<|LOC_257|>": 100554, "<|LOC_258|>": 100555, "<|LOC_259|>": 100556, "<|LOC_260|>": 100557, "<|LOC_261|>": 100558, "<|LOC_262|>": 100559, "<|LOC_263|>": 100560, "<|LOC_264|>": 100561, "<|LOC_265|>": 100562, "<|LOC_266|>": 100563, "<|LOC_267|>": 100564, "<|LOC_268|>": 100565, "<|LOC_269|>": 100566, "<|LOC_270|>": 100567, "<|LOC_271|>": 100568, "<|LOC_272|>": 100569, "<|LOC_273|>": 100570, "<|LOC_274|>": 100571, "<|LOC_275|>": 100572, "<|LOC_276|>": 100573, "<|LOC_277|>": 100574, "<|LOC_278|>": 100575, "<|LOC_279|>": 100576, "<|LOC_280|>": 100577, "<|LOC_281|>": 100578, "<|LOC_282|>": 100579, "<|LOC_283|>": 100580, "<|LOC_284|>": 100581, "<|LOC_285|>": 100582, "<|LOC_286|>": 100583, "<|LOC_287|>": 100584, "<|LOC_288|>": 100585, "<|LOC_289|>": 100586, "<|LOC_290|>": 100587, "<|LOC_291|>": 100588, "<|LOC_292|>": 100589, "<|LOC_293|>": 100590, "<|LOC_294|>": 100591, "<|LOC_295|>": 100592, "<|LOC_296|>": 100593, "<|LOC_297|>": 100594, "<|LOC_298|>": 100595, "<|LOC_299|>": 100596, "<|LOC_300|>": 100597, "<|LOC_301|>": 100598, "<|LOC_302|>": 100599, "<|LOC_303|>": 100600, "<|LOC_304|>": 100601, "<|LOC_305|>": 100602, "<|LOC_306|>": 100603, "<|LOC_307|>": 100604, "<|LOC_308|>": 100605, "<|LOC_309|>": 100606, 
"<|LOC_310|>": 100607, "<|LOC_311|>": 100608, "<|LOC_312|>": 100609, "<|LOC_313|>": 100610, "<|LOC_314|>": 100611, "<|LOC_315|>": 100612, "<|LOC_316|>": 100613, "<|LOC_317|>": 100614, "<|LOC_318|>": 100615, "<|LOC_319|>": 100616, "<|LOC_320|>": 100617, "<|LOC_321|>": 100618, "<|LOC_322|>": 100619, "<|LOC_323|>": 100620, "<|LOC_324|>": 100621, "<|LOC_325|>": 100622, "<|LOC_326|>": 100623, "<|LOC_327|>": 100624, "<|LOC_328|>": 100625, "<|LOC_329|>": 100626, "<|LOC_330|>": 100627, "<|LOC_331|>": 100628, "<|LOC_332|>": 100629, "<|LOC_333|>": 100630, "<|LOC_334|>": 100631, "<|LOC_335|>": 100632, "<|LOC_336|>": 100633, "<|LOC_337|>": 100634, "<|LOC_338|>": 100635, "<|LOC_339|>": 100636, "<|LOC_340|>": 100637, "<|LOC_341|>": 100638, "<|LOC_342|>": 100639, "<|LOC_343|>": 100640, "<|LOC_344|>": 100641, "<|LOC_345|>": 100642, "<|LOC_346|>": 100643, "<|LOC_347|>": 100644, "<|LOC_348|>": 100645, "<|LOC_349|>": 100646, "<|LOC_350|>": 100647, "<|LOC_351|>": 100648, "<|LOC_352|>": 100649, "<|LOC_353|>": 100650, "<|LOC_354|>": 100651, "<|LOC_355|>": 100652, "<|LOC_356|>": 100653, "<|LOC_357|>": 100654, "<|LOC_358|>": 100655, "<|LOC_359|>": 100656, "<|LOC_360|>": 100657, "<|LOC_361|>": 100658, "<|LOC_362|>": 100659, "<|LOC_363|>": 100660, "<|LOC_364|>": 100661, "<|LOC_365|>": 100662, "<|LOC_366|>": 100663, "<|LOC_367|>": 100664, "<|LOC_368|>": 100665, "<|LOC_369|>": 100666, "<|LOC_370|>": 100667, "<|LOC_371|>": 100668, "<|LOC_372|>": 100669, "<|LOC_373|>": 100670, "<|LOC_374|>": 100671, "<|LOC_375|>": 100672, "<|LOC_376|>": 100673, "<|LOC_377|>": 100674, "<|LOC_378|>": 100675, "<|LOC_379|>": 100676, "<|LOC_380|>": 100677, "<|LOC_381|>": 100678, "<|LOC_382|>": 100679, "<|LOC_383|>": 100680, "<|LOC_384|>": 100681, "<|LOC_385|>": 100682, "<|LOC_386|>": 100683, "<|LOC_387|>": 100684, "<|LOC_388|>": 100685, "<|LOC_389|>": 100686, "<|LOC_390|>": 100687, "<|LOC_391|>": 100688, "<|LOC_392|>": 100689, "<|LOC_393|>": 100690, "<|LOC_394|>": 100691, "<|LOC_395|>": 100692, "<|LOC_396|>": 100693, "<|LOC_397|>": 100694, "<|LOC_398|>": 100695, "<|LOC_399|>": 100696, "<|LOC_400|>": 100697, "<|LOC_401|>": 100698, "<|LOC_402|>": 100699, "<|LOC_403|>": 100700, "<|LOC_404|>": 100701, "<|LOC_405|>": 100702, "<|LOC_406|>": 100703, "<|LOC_407|>": 100704, "<|LOC_408|>": 100705, "<|LOC_409|>": 100706, "<|LOC_410|>": 100707, "<|LOC_411|>": 100708, "<|LOC_412|>": 100709, "<|LOC_413|>": 100710, "<|LOC_414|>": 100711, "<|LOC_415|>": 100712, "<|LOC_416|>": 100713, "<|LOC_417|>": 100714, "<|LOC_418|>": 100715, "<|LOC_419|>": 100716, "<|LOC_420|>": 100717, "<|LOC_421|>": 100718, "<|LOC_422|>": 100719, "<|LOC_423|>": 100720, "<|LOC_424|>": 100721, "<|LOC_425|>": 100722, "<|LOC_426|>": 100723, "<|LOC_427|>": 100724, "<|LOC_428|>": 100725, "<|LOC_429|>": 100726, "<|LOC_430|>": 100727, "<|LOC_431|>": 100728, "<|LOC_432|>": 100729, "<|LOC_433|>": 100730, "<|LOC_434|>": 100731, "<|LOC_435|>": 100732, "<|LOC_436|>": 100733, "<|LOC_437|>": 100734, "<|LOC_438|>": 100735, "<|LOC_439|>": 100736, "<|LOC_440|>": 100737, "<|LOC_441|>": 100738, "<|LOC_442|>": 100739, "<|LOC_443|>": 100740, "<|LOC_444|>": 100741, "<|LOC_445|>": 100742, "<|LOC_446|>": 100743, "<|LOC_447|>": 100744, "<|LOC_448|>": 100745, "<|LOC_449|>": 100746, "<|LOC_450|>": 100747, "<|LOC_451|>": 100748, "<|LOC_452|>": 100749, "<|LOC_453|>": 100750, "<|LOC_454|>": 100751, "<|LOC_455|>": 100752, "<|LOC_456|>": 100753, "<|LOC_457|>": 100754, "<|LOC_458|>": 100755, "<|LOC_459|>": 100756, "<|LOC_460|>": 100757, "<|LOC_461|>": 100758, "<|LOC_462|>": 100759, "<|LOC_463|>": 100760, 
"<|LOC_464|>": 100761, "<|LOC_465|>": 100762, "<|LOC_466|>": 100763, "<|LOC_467|>": 100764, "<|LOC_468|>": 100765, "<|LOC_469|>": 100766, "<|LOC_470|>": 100767, "<|LOC_471|>": 100768, "<|LOC_472|>": 100769, "<|LOC_473|>": 100770, "<|LOC_474|>": 100771, "<|LOC_475|>": 100772, "<|LOC_476|>": 100773, "<|LOC_477|>": 100774, "<|LOC_478|>": 100775, "<|LOC_479|>": 100776, "<|LOC_480|>": 100777, "<|LOC_481|>": 100778, "<|LOC_482|>": 100779, "<|LOC_483|>": 100780, "<|LOC_484|>": 100781, "<|LOC_485|>": 100782, "<|LOC_486|>": 100783, "<|LOC_487|>": 100784, "<|LOC_488|>": 100785, "<|LOC_489|>": 100786, "<|LOC_490|>": 100787, "<|LOC_491|>": 100788, "<|LOC_492|>": 100789, "<|LOC_493|>": 100790, "<|LOC_494|>": 100791, "<|LOC_495|>": 100792, "<|LOC_496|>": 100793, "<|LOC_497|>": 100794, "<|LOC_498|>": 100795, "<|LOC_499|>": 100796, "<|LOC_500|>": 100797, "<|LOC_501|>": 100798, "<|LOC_502|>": 100799, "<|LOC_503|>": 100800, "<|LOC_504|>": 100801, "<|LOC_505|>": 100802, "<|LOC_506|>": 100803, "<|LOC_507|>": 100804, "<|LOC_508|>": 100805, "<|LOC_509|>": 100806, "<|LOC_510|>": 100807, "<|LOC_511|>": 100808, "<|LOC_512|>": 100809, "<|LOC_513|>": 100810, "<|LOC_514|>": 100811, "<|LOC_515|>": 100812, "<|LOC_516|>": 100813, "<|LOC_517|>": 100814, "<|LOC_518|>": 100815, "<|LOC_519|>": 100816, "<|LOC_520|>": 100817, "<|LOC_521|>": 100818, "<|LOC_522|>": 100819, "<|LOC_523|>": 100820, "<|LOC_524|>": 100821, "<|LOC_525|>": 100822, "<|LOC_526|>": 100823, "<|LOC_527|>": 100824, "<|LOC_528|>": 100825, "<|LOC_529|>": 100826, "<|LOC_530|>": 100827, "<|LOC_531|>": 100828, "<|LOC_532|>": 100829, "<|LOC_533|>": 100830, "<|LOC_534|>": 100831, "<|LOC_535|>": 100832, "<|LOC_536|>": 100833, "<|LOC_537|>": 100834, "<|LOC_538|>": 100835, "<|LOC_539|>": 100836, "<|LOC_540|>": 100837, "<|LOC_541|>": 100838, "<|LOC_542|>": 100839, "<|LOC_543|>": 100840, "<|LOC_544|>": 100841, "<|LOC_545|>": 100842, "<|LOC_546|>": 100843, "<|LOC_547|>": 100844, "<|LOC_548|>": 100845, "<|LOC_549|>": 100846, "<|LOC_550|>": 100847, "<|LOC_551|>": 100848, "<|LOC_552|>": 100849, "<|LOC_553|>": 100850, "<|LOC_554|>": 100851, "<|LOC_555|>": 100852, "<|LOC_556|>": 100853, "<|LOC_557|>": 100854, "<|LOC_558|>": 100855, "<|LOC_559|>": 100856, "<|LOC_560|>": 100857, "<|LOC_561|>": 100858, "<|LOC_562|>": 100859, "<|LOC_563|>": 100860, "<|LOC_564|>": 100861, "<|LOC_565|>": 100862, "<|LOC_566|>": 100863, "<|LOC_567|>": 100864, "<|LOC_568|>": 100865, "<|LOC_569|>": 100866, "<|LOC_570|>": 100867, "<|LOC_571|>": 100868, "<|LOC_572|>": 100869, "<|LOC_573|>": 100870, "<|LOC_574|>": 100871, "<|LOC_575|>": 100872, "<|LOC_576|>": 100873, "<|LOC_577|>": 100874, "<|LOC_578|>": 100875, "<|LOC_579|>": 100876, "<|LOC_580|>": 100877, "<|LOC_581|>": 100878, "<|LOC_582|>": 100879, "<|LOC_583|>": 100880, "<|LOC_584|>": 100881, "<|LOC_585|>": 100882, "<|LOC_586|>": 100883, "<|LOC_587|>": 100884, "<|LOC_588|>": 100885, "<|LOC_589|>": 100886, "<|LOC_590|>": 100887, "<|LOC_591|>": 100888, "<|LOC_592|>": 100889, "<|LOC_593|>": 100890, "<|LOC_594|>": 100891, "<|LOC_595|>": 100892, "<|LOC_596|>": 100893, "<|LOC_597|>": 100894, "<|LOC_598|>": 100895, "<|LOC_599|>": 100896, "<|LOC_600|>": 100897, "<|LOC_601|>": 100898, "<|LOC_602|>": 100899, "<|LOC_603|>": 100900, "<|LOC_604|>": 100901, "<|LOC_605|>": 100902, "<|LOC_606|>": 100903, "<|LOC_607|>": 100904, "<|LOC_608|>": 100905, "<|LOC_609|>": 100906, "<|LOC_610|>": 100907, "<|LOC_611|>": 100908, "<|LOC_612|>": 100909, "<|LOC_613|>": 100910, "<|LOC_614|>": 100911, "<|LOC_615|>": 100912, "<|LOC_616|>": 100913, "<|LOC_617|>": 100914, 
"<|LOC_618|>": 100915, "<|LOC_619|>": 100916, "<|LOC_620|>": 100917, "<|LOC_621|>": 100918, "<|LOC_622|>": 100919, "<|LOC_623|>": 100920, "<|LOC_624|>": 100921, "<|LOC_625|>": 100922, "<|LOC_626|>": 100923, "<|LOC_627|>": 100924, "<|LOC_628|>": 100925, "<|LOC_629|>": 100926, "<|LOC_630|>": 100927, "<|LOC_631|>": 100928, "<|LOC_632|>": 100929, "<|LOC_633|>": 100930, "<|LOC_634|>": 100931, "<|LOC_635|>": 100932, "<|LOC_636|>": 100933, "<|LOC_637|>": 100934, "<|LOC_638|>": 100935, "<|LOC_639|>": 100936, "<|LOC_640|>": 100937, "<|LOC_641|>": 100938, "<|LOC_642|>": 100939, "<|LOC_643|>": 100940, "<|LOC_644|>": 100941, "<|LOC_645|>": 100942, "<|LOC_646|>": 100943, "<|LOC_647|>": 100944, "<|LOC_648|>": 100945, "<|LOC_649|>": 100946, "<|LOC_650|>": 100947, "<|LOC_651|>": 100948, "<|LOC_652|>": 100949, "<|LOC_653|>": 100950, "<|LOC_654|>": 100951, "<|LOC_655|>": 100952, "<|LOC_656|>": 100953, "<|LOC_657|>": 100954, "<|LOC_658|>": 100955, "<|LOC_659|>": 100956, "<|LOC_660|>": 100957, "<|LOC_661|>": 100958, "<|LOC_662|>": 100959, "<|LOC_663|>": 100960, "<|LOC_664|>": 100961, "<|LOC_665|>": 100962, "<|LOC_666|>": 100963, "<|LOC_667|>": 100964, "<|LOC_668|>": 100965, "<|LOC_669|>": 100966, "<|LOC_670|>": 100967, "<|LOC_671|>": 100968, "<|LOC_672|>": 100969, "<|LOC_673|>": 100970, "<|LOC_674|>": 100971, "<|LOC_675|>": 100972, "<|LOC_676|>": 100973, "<|LOC_677|>": 100974, "<|LOC_678|>": 100975, "<|LOC_679|>": 100976, "<|LOC_680|>": 100977, "<|LOC_681|>": 100978, "<|LOC_682|>": 100979, "<|LOC_683|>": 100980, "<|LOC_684|>": 100981, "<|LOC_685|>": 100982, "<|LOC_686|>": 100983, "<|LOC_687|>": 100984, "<|LOC_688|>": 100985, "<|LOC_689|>": 100986, "<|LOC_690|>": 100987, "<|LOC_691|>": 100988, "<|LOC_692|>": 100989, "<|LOC_693|>": 100990, "<|LOC_694|>": 100991, "<|LOC_695|>": 100992, "<|LOC_696|>": 100993, "<|LOC_697|>": 100994, "<|LOC_698|>": 100995, "<|LOC_699|>": 100996, "<|LOC_700|>": 100997, "<|LOC_701|>": 100998, "<|LOC_702|>": 100999, "<|LOC_703|>": 101000, "<|LOC_704|>": 101001, "<|LOC_705|>": 101002, "<|LOC_706|>": 101003, "<|LOC_707|>": 101004, "<|LOC_708|>": 101005, "<|LOC_709|>": 101006, "<|LOC_710|>": 101007, "<|LOC_711|>": 101008, "<|LOC_712|>": 101009, "<|LOC_713|>": 101010, "<|LOC_714|>": 101011, "<|LOC_715|>": 101012, "<|LOC_716|>": 101013, "<|LOC_717|>": 101014, "<|LOC_718|>": 101015, "<|LOC_719|>": 101016, "<|LOC_720|>": 101017, "<|LOC_721|>": 101018, "<|LOC_722|>": 101019, "<|LOC_723|>": 101020, "<|LOC_724|>": 101021, "<|LOC_725|>": 101022, "<|LOC_726|>": 101023, "<|LOC_727|>": 101024, "<|LOC_728|>": 101025, "<|LOC_729|>": 101026, "<|LOC_730|>": 101027, "<|LOC_731|>": 101028, "<|LOC_732|>": 101029, "<|LOC_733|>": 101030, "<|LOC_734|>": 101031, "<|LOC_735|>": 101032, "<|LOC_736|>": 101033, "<|LOC_737|>": 101034, "<|LOC_738|>": 101035, "<|LOC_739|>": 101036, "<|LOC_740|>": 101037, "<|LOC_741|>": 101038, "<|LOC_742|>": 101039, "<|LOC_743|>": 101040, "<|LOC_744|>": 101041, "<|LOC_745|>": 101042, "<|LOC_746|>": 101043, "<|LOC_747|>": 101044, "<|LOC_748|>": 101045, "<|LOC_749|>": 101046, "<|LOC_750|>": 101047, "<|LOC_751|>": 101048, "<|LOC_752|>": 101049, "<|LOC_753|>": 101050, "<|LOC_754|>": 101051, "<|LOC_755|>": 101052, "<|LOC_756|>": 101053, "<|LOC_757|>": 101054, "<|LOC_758|>": 101055, "<|LOC_759|>": 101056, "<|LOC_760|>": 101057, "<|LOC_761|>": 101058, "<|LOC_762|>": 101059, "<|LOC_763|>": 101060, "<|LOC_764|>": 101061, "<|LOC_765|>": 101062, "<|LOC_766|>": 101063, "<|LOC_767|>": 101064, "<|LOC_768|>": 101065, "<|LOC_769|>": 101066, "<|LOC_770|>": 101067, "<|LOC_771|>": 101068, 
"<|LOC_772|>": 101069, "<|LOC_773|>": 101070, "<|LOC_774|>": 101071, "<|LOC_775|>": 101072, "<|LOC_776|>": 101073, "<|LOC_777|>": 101074, "<|LOC_778|>": 101075, "<|LOC_779|>": 101076, "<|LOC_780|>": 101077, "<|LOC_781|>": 101078, "<|LOC_782|>": 101079, "<|LOC_783|>": 101080, "<|LOC_784|>": 101081, "<|LOC_785|>": 101082, "<|LOC_786|>": 101083, "<|LOC_787|>": 101084, "<|LOC_788|>": 101085, "<|LOC_789|>": 101086, "<|LOC_790|>": 101087, "<|LOC_791|>": 101088, "<|LOC_792|>": 101089, "<|LOC_793|>": 101090, "<|LOC_794|>": 101091, "<|LOC_795|>": 101092, "<|LOC_796|>": 101093, "<|LOC_797|>": 101094, "<|LOC_798|>": 101095, "<|LOC_799|>": 101096, "<|LOC_800|>": 101097, "<|LOC_801|>": 101098, "<|LOC_802|>": 101099, "<|LOC_803|>": 101100, "<|LOC_804|>": 101101, "<|LOC_805|>": 101102, "<|LOC_806|>": 101103, "<|LOC_807|>": 101104, "<|LOC_808|>": 101105, "<|LOC_809|>": 101106, "<|LOC_810|>": 101107, "<|LOC_811|>": 101108, "<|LOC_812|>": 101109, "<|LOC_813|>": 101110, "<|LOC_814|>": 101111, "<|LOC_815|>": 101112, "<|LOC_816|>": 101113, "<|LOC_817|>": 101114, "<|LOC_818|>": 101115, "<|LOC_819|>": 101116, "<|LOC_820|>": 101117, "<|LOC_821|>": 101118, "<|LOC_822|>": 101119, "<|LOC_823|>": 101120, "<|LOC_824|>": 101121, "<|LOC_825|>": 101122, "<|LOC_826|>": 101123, "<|LOC_827|>": 101124, "<|LOC_828|>": 101125, "<|LOC_829|>": 101126, "<|LOC_830|>": 101127, "<|LOC_831|>": 101128, "<|LOC_832|>": 101129, "<|LOC_833|>": 101130, "<|LOC_834|>": 101131, "<|LOC_835|>": 101132, "<|LOC_836|>": 101133, "<|LOC_837|>": 101134, "<|LOC_838|>": 101135, "<|LOC_839|>": 101136, "<|LOC_840|>": 101137, "<|LOC_841|>": 101138, "<|LOC_842|>": 101139, "<|LOC_843|>": 101140, "<|LOC_844|>": 101141, "<|LOC_845|>": 101142, "<|LOC_846|>": 101143, "<|LOC_847|>": 101144, "<|LOC_848|>": 101145, "<|LOC_849|>": 101146, "<|LOC_850|>": 101147, "<|LOC_851|>": 101148, "<|LOC_852|>": 101149, "<|LOC_853|>": 101150, "<|LOC_854|>": 101151, "<|LOC_855|>": 101152, "<|LOC_856|>": 101153, "<|LOC_857|>": 101154, "<|LOC_858|>": 101155, "<|LOC_859|>": 101156, "<|LOC_860|>": 101157, "<|LOC_861|>": 101158, "<|LOC_862|>": 101159, "<|LOC_863|>": 101160, "<|LOC_864|>": 101161, "<|LOC_865|>": 101162, "<|LOC_866|>": 101163, "<|LOC_867|>": 101164, "<|LOC_868|>": 101165, "<|LOC_869|>": 101166, "<|LOC_870|>": 101167, "<|LOC_871|>": 101168, "<|LOC_872|>": 101169, "<|LOC_873|>": 101170, "<|LOC_874|>": 101171, "<|LOC_875|>": 101172, "<|LOC_876|>": 101173, "<|LOC_877|>": 101174, "<|LOC_878|>": 101175, "<|LOC_879|>": 101176, "<|LOC_880|>": 101177, "<|LOC_881|>": 101178, "<|LOC_882|>": 101179, "<|LOC_883|>": 101180, "<|LOC_884|>": 101181, "<|LOC_885|>": 101182, "<|LOC_886|>": 101183, "<|LOC_887|>": 101184, "<|LOC_888|>": 101185, "<|LOC_889|>": 101186, "<|LOC_890|>": 101187, "<|LOC_891|>": 101188, "<|LOC_892|>": 101189, "<|LOC_893|>": 101190, "<|LOC_894|>": 101191, "<|LOC_895|>": 101192, "<|LOC_896|>": 101193, "<|LOC_897|>": 101194, "<|LOC_898|>": 101195, "<|LOC_899|>": 101196, "<|LOC_900|>": 101197, "<|LOC_901|>": 101198, "<|LOC_902|>": 101199, "<|LOC_903|>": 101200, "<|LOC_904|>": 101201, "<|LOC_905|>": 101202, "<|LOC_906|>": 101203, "<|LOC_907|>": 101204, "<|LOC_908|>": 101205, "<|LOC_909|>": 101206, "<|LOC_910|>": 101207, "<|LOC_911|>": 101208, "<|LOC_912|>": 101209, "<|LOC_913|>": 101210, "<|LOC_914|>": 101211, "<|LOC_915|>": 101212, "<|LOC_916|>": 101213, "<|LOC_917|>": 101214, "<|LOC_918|>": 101215, "<|LOC_919|>": 101216, "<|LOC_920|>": 101217, "<|LOC_921|>": 101218, "<|LOC_922|>": 101219, "<|LOC_923|>": 101220, "<|LOC_924|>": 101221, "<|LOC_925|>": 101222, 
"<|LOC_926|>": 101223, "<|LOC_927|>": 101224, "<|LOC_928|>": 101225, "<|LOC_929|>": 101226, "<|LOC_930|>": 101227, "<|LOC_931|>": 101228, "<|LOC_932|>": 101229, "<|LOC_933|>": 101230, "<|LOC_934|>": 101231, "<|LOC_935|>": 101232, "<|LOC_936|>": 101233, "<|LOC_937|>": 101234, "<|LOC_938|>": 101235, "<|LOC_939|>": 101236, "<|LOC_940|>": 101237, "<|LOC_941|>": 101238, "<|LOC_942|>": 101239, "<|LOC_943|>": 101240, "<|LOC_944|>": 101241, "<|LOC_945|>": 101242, "<|LOC_946|>": 101243, "<|LOC_947|>": 101244, "<|LOC_948|>": 101245, "<|LOC_949|>": 101246, "<|LOC_950|>": 101247, "<|LOC_951|>": 101248, "<|LOC_952|>": 101249, "<|LOC_953|>": 101250, "<|LOC_954|>": 101251, "<|LOC_955|>": 101252, "<|LOC_956|>": 101253, "<|LOC_957|>": 101254, "<|LOC_958|>": 101255, "<|LOC_959|>": 101256, "<|LOC_960|>": 101257, "<|LOC_961|>": 101258, "<|LOC_962|>": 101259, "<|LOC_963|>": 101260, "<|LOC_964|>": 101261, "<|LOC_965|>": 101262, "<|LOC_966|>": 101263, "<|LOC_967|>": 101264, "<|LOC_968|>": 101265, "<|LOC_969|>": 101266, "<|LOC_970|>": 101267, "<|LOC_971|>": 101268, "<|LOC_972|>": 101269, "<|LOC_973|>": 101270, "<|LOC_974|>": 101271, "<|LOC_975|>": 101272, "<|LOC_976|>": 101273, "<|LOC_977|>": 101274, "<|LOC_978|>": 101275, "<|LOC_979|>": 101276, "<|LOC_980|>": 101277, "<|LOC_981|>": 101278, "<|LOC_982|>": 101279, "<|LOC_983|>": 101280, "<|LOC_984|>": 101281, "<|LOC_985|>": 101282, "<|LOC_986|>": 101283, "<|LOC_987|>": 101284, "<|LOC_988|>": 101285, "<|LOC_989|>": 101286, "<|LOC_990|>": 101287, "<|LOC_991|>": 101288, "<|LOC_992|>": 101289, "<|LOC_993|>": 101290, "<|LOC_994|>": 101291, "<|LOC_995|>": 101292, "<|LOC_996|>": 101293, "<|LOC_997|>": 101294, "<|LOC_998|>": 101295, "<|LOC_999|>": 101296, "<|LOC_1000|>": 101297, "<|LOC_BEGIN|>": 101298, "<|LOC_END|>": 101299, "<|LOC_SEP|>": 101300, "<|CROP_COL_SEP|>": 101301, "<|CROP_ROW_SEP|>": 101302, "<|IMAGE_SEP|>": 101303}
config.json ADDED
@@ -0,0 +1,42 @@
+ {
+   "_attn_implementation": "sdpa",
+   "architectures": [
+     "Ernie4_5_MoeForCausalLM"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_ernie4_5_moe.Ernie4_5_MoeConfig",
+     "AutoModel": "modeling_ernie4_5_moe.Ernie4_5_Model",
+     "AutoModelForCausalLM": "modeling_ernie4_5_moe.Ernie4_5_MoeForCausalLM"
+   },
+   "bos_token_id": 1,
+   "eos_token_id": 2,
+   "hidden_act": "silu",
+   "hidden_size": 8192,
+   "intermediate_size": 28672,
+   "max_position_embeddings": 131072,
+   "model_type": "ernie4_5_moe",
+   "moe_capacity": [
+     64,
+     64,
+     64
+   ],
+   "moe_gate": "topk",
+   "moe_intermediate_size": 3584,
+   "moe_k": 8,
+   "moe_layer_interval": 1,
+   "moe_layer_start_index": 3,
+   "moe_num_experts": 64,
+   "moe_use_aux_free": true,
+   "num_attention_heads": 64,
+   "num_hidden_layers": 54,
+   "num_key_value_heads": 8,
+   "num_nextn_predict_layers": 1,
+   "pad_token_id": 0,
+   "rms_norm_eps": 1e-05,
+   "rope_theta": 500000,
+   "tie_word_embeddings": false,
+   "torch_dtype": "bfloat16",
+   "use_bias": false,
+   "use_cache": true,
+   "vocab_size": 103424
+ }
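
This configuration is resolved through the `auto_map` entries above when loading with `trust_remote_code=True`; a quick sketch of inspecting it with the standard `transformers` API:

```python
from transformers import AutoConfig

# Pulls in configuration_ernie4_5_moe.py via the auto_map entries above.
config = AutoConfig.from_pretrained("baidu/ERNIE-4.5-300B-A47B-PT", trust_remote_code=True)
print(config.model_type, config.num_hidden_layers, config.moe_num_experts)
```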
configuration_ernie4_5_moe.py ADDED
@@ -0,0 +1,192 @@
+ # Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Ernie4_5_Moe model configuration"""
+
+ from transformers import PretrainedConfig
+
+
+
+ class Ernie4_5_MoeConfig(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a [`Ernie4_5_Model`].
+     It is used to instantiate an ERNIE-4.5 model according to the specified arguments,
+     defining the model architecture. Instantiating a configuration with the defaults
+     will yield a similar configuration to that of ERNIE-4.5-300B-A47B-PT [baidu/ERNIE-4.5-300B-A47B-PT].
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+
+     Args:
+         vocab_size (int): Size of the vocabulary (number of unique tokens)
+         hidden_size (int): Dimensionality of the encoder layers and the pooler layer
+         intermediate_size (int): Dimensionality of the "intermediate" (feed-forward) layer
+         max_position_embeddings (int): Maximum sequence length the model can handle
+         num_hidden_layers (int): Number of hidden layers in the Transformer encoder
+         num_attention_heads (int): Number of attention heads for each attention layer
+         rms_norm_eps (float): The epsilon used by the RMS normalization layers
+         use_cache (bool): Whether to use caching for faster generation (decoding)
+         use_flash_attention (bool): Whether to use FlashAttention for optimized attention computation
+         pad_token_id (int): Token ID used for padding sequences
+         bos_token_id (int): Token ID used for beginning-of-sequence
+         eos_token_id (int): Token ID used for end-of-sequence
+         use_bias (bool): Whether to use bias terms in linear layers
+         rope_theta (float): The base period of the RoPE embeddings
+         weight_share_add_bias (bool): Whether to share bias weights in certain layers
+         ignored_index (int): Target value that is ignored during loss computation
+         attention_probs_dropout_prob (float): Dropout probability for attention weights
+         hidden_dropout_prob (float): Dropout probability for hidden layers
+         num_key_value_heads (int): Number of key/value heads (for Grouped Query Attention)
+         max_sequence_length (int): Maximum sequence length for positional embeddings
+         moe_num_experts: Number of experts in MoE layers
+         moe_capacity: Capacity configuration for MoE layers
+         moe_layer_interval: Interval between MoE layers
+         moe_layer_start_index: Starting layer index for MoE
+         moe_layer_end_index: Ending layer index for MoE (-1 means last layer)
+         sinkhorn_2gate: Whether to use sinkhorn 2-gate routing
+         sinkhorn_temp: Temperature for sinkhorn routing
+         moe_dropout_prob: Dropout probability for MoE layers
+         moe_gate: Type of gating mechanism ('top2', etc.)
+         moe_intermediate_size: Intermediate size for MoE layers
+         moe_gate_act: Activation function for gating
+         moe_k: Number of experts to route to
+         **kwargs: Additional base model configuration parameters
+     """
+
+     model_type = "ernie4_5_moe"
+     use_keep_in_fp32_modules = True
+     keys_to_ignore_at_inference = ["past_key_values"]
+
+     attribute_map = {
+         "n_positions": "max_position_embeddings",
+         "n_embd": "hidden_size",
+         "n_layer": "num_hidden_layers",
+         "n_head": "num_attention_heads",
+         "n_inner": "intermediate_size",
+         "activation_function": "hidden_act",
+     }
+
+     # Default tensor parallel plan for base model `ernie_4_5_moe`
+     base_model_tp_plan = {
+         "model.layers.*.self_attn.q_proj": "colwise_rep",
+         "model.layers.*.self_attn.k_proj": "colwise_rep",
+         "model.layers.*.self_attn.v_proj": "colwise_rep",
+         "model.layers.*.self_attn.o_proj": "rowwise_rep",
+         "model.layers.*.mlp.experts.*.gate_proj": "colwise",
+         "model.layers.*.mlp.experts.*.up_proj": "colwise",
+         "model.layers.*.mlp.experts.*.down_proj": "rowwise",
+         "model.layers.*.mlp.gate_proj": "colwise",
+         "model.layers.*.mlp.up_proj": "colwise",
+         "model.layers.*.mlp.down_proj": "rowwise",
+     }
+     base_model_pp_plan = {
+         "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+         "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+         "norm": (["hidden_states"], ["hidden_states"]),
+     }
+
+     def __init__(
+         self,
+         vocab_size=32000,
+         hidden_size=768,
+         intermediate_size=11008,
+         num_hidden_layers=2,
+         num_attention_heads=2,
+         num_key_value_heads=None,
+         max_position_embeddings=32768,
+         use_sliding_window=None,
+         sliding_window=None,
+         rms_norm_eps=1e-6,
+         use_cache=False,
+         pad_token_id=0,
+         bos_token_id=1,
+         eos_token_id=2,
+         attention_probs_dropout_prob=0.0,
+         hidden_dropout_prob=0.0,
+         rope_theta=10000.0,
+         use_flash_attention=False,
+         use_rmsnorm=True,
+         use_bias=False,
+         weight_share_add_bias=True,
+         max_sequence_length=None,
+         ignored_index=-100,
+         use_moe=True,
+         moe_num_experts=64,
+         moe_capacity=(64, 64, 64),
+         moe_layer_interval=2,
+         moe_layer_start_index=0,
+         moe_layer_end_index=-1,
+         sinkhorn_2gate=True,
+         sinkhorn_temp=3e-2,
+         moe_dropout_prob=0.0,
+         moe_gate="top2",
+         moe_intermediate_size=3584,
+         moe_k=2,
+         moe_gate_act="softmax",
+         moe_use_aux_free=False,
+         **kwargs
+     ):
+         self.vocab_size = vocab_size
+         self.max_position_embeddings = max_position_embeddings
+         self.use_sliding_window = use_sliding_window
+         self.sliding_window = sliding_window
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+
+         if num_key_value_heads is None:
+             num_key_value_heads = num_attention_heads
+
+         self.num_key_value_heads = num_key_value_heads
+         self.use_rmsnorm = use_rmsnorm
+         self.rms_norm_eps = rms_norm_eps
+         self.rope_theta = rope_theta
+         self.max_sequence_length = max_sequence_length
+         self.pad_token_id = pad_token_id
+         self.bos_token_id = bos_token_id
+         self.eos_token_id = eos_token_id
+         self.ignored_index = ignored_index
+         self.use_cache = use_cache
+         self.use_bias = use_bias
+         self.weight_share_add_bias = weight_share_add_bias
+         self.use_flash_attention = use_flash_attention
+         self.attention_probs_dropout_prob = attention_probs_dropout_prob
+         self.hidden_dropout_prob = hidden_dropout_prob
+
+         self.use_moe = moe_num_experts > 0 and use_moe
+         self.moe_num_experts = moe_num_experts
+         self.moe_capacity = moe_capacity
+         self.sinkhorn_2gate = sinkhorn_2gate
+         self.sinkhorn_temp = sinkhorn_temp
+         self.moe_layer_interval = moe_layer_interval
+         self.moe_dropout_prob = moe_dropout_prob
+         self.moe_gate = moe_gate
+         self.moe_intermediate_size = moe_intermediate_size
+         self.moe_k = moe_k
+         self.moe_layer_start_index = moe_layer_start_index
+         self.moe_layer_end_index = self.num_hidden_layers - 1 if moe_layer_end_index == -1 else moe_layer_end_index
+         self.moe_gate_act = moe_gate_act
+         self.moe_use_aux_free = moe_use_aux_free
+
+         # Set default for tied embeddings if not specified.
+         if "tie_word_embeddings" not in kwargs:
+             kwargs["tie_word_embeddings"] = False
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             **kwargs,
+         )
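
A brief usage sketch of this configuration class (values chosen to mirror config.json above; note how `moe_layer_end_index=-1` is resolved in `__init__`):

```python
from configuration_ernie4_5_moe import Ernie4_5_MoeConfig

config = Ernie4_5_MoeConfig(
    hidden_size=8192,
    num_hidden_layers=54,
    moe_num_experts=64,
    moe_k=8,
    moe_layer_start_index=3,
    moe_layer_end_index=-1,  # resolved to num_hidden_layers - 1
)
print(config.moe_layer_end_index)  # 53
```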
generation_config.json ADDED
@@ -0,0 +1,13 @@
1
+ {
2
+ "do_sample": true,
3
+ "top_p": 0.8,
4
+ "temperature": 0.8,
5
+ "repetition_penalty": 1.0,
6
+ "frequency_penalty": 0.0,
7
+ "presence_penalty": 0.0,
8
+ "bos_token_id": 1,
9
+ "eos_token_id": 2,
10
+ "pad_token_id": 0,
11
+ "transformers_version": "4.52.4",
12
+ "use_cache": true
13
+ }
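These defaults map directly onto transformers' `GenerationConfig`; a quick sketch of the equivalent object built by hand (only the standard fields are shown):

```python
# Sketch: the JSON above corresponds to this GenerationConfig.
from transformers import GenerationConfig

gen_cfg = GenerationConfig(
    do_sample=True,
    top_p=0.8,
    temperature=0.8,
    repetition_penalty=1.0,
    bos_token_id=1,
    eos_token_id=2,
    pad_token_id=0,
)
print(gen_cfg.do_sample, gen_cfg.top_p)  # True 0.8
```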
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_ernie4_5_moe.py ADDED
@@ -0,0 +1,1590 @@
1
+ # Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """ Ernie4_5_Moe model """
15
+
16
+ from copy import deepcopy
17
+ from dataclasses import dataclass
18
+ from functools import partial, wraps
19
+ from typing import Callable, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ import torch.nn as nn
24
+
25
+ from transformers.cache_utils import (
26
+ Cache,
27
+ DynamicCache,
28
+ SlidingWindowCache,
29
+ StaticCache,
30
+ )
31
+ from transformers.generation import GenerationMixin
32
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
33
+ from transformers.modeling_outputs import ModelOutput, MoeCausalLMOutputWithPast
34
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
35
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
36
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
37
+ from transformers.processing_utils import Unpack
38
+ from transformers.utils import (
39
+ LossKwargs,
40
+ auto_docstring,
41
+ can_return_tuple,
42
+ logging,
43
+ is_torch_flex_attn_available,
44
+ )
45
+
46
+ from .configuration_ernie4_5_moe import Ernie4_5_MoeConfig
47
+
48
+
49
+ if is_torch_flex_attn_available():
50
+ from torch.nn.attention.flex_attention import BlockMask
51
+
52
+ from transformers.integrations.flex_attention import make_flex_block_causal_mask
53
+
54
+ logger = logging.get_logger(__name__)
55
+
56
+
57
+ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
58
+ """Kwargs class used during autoregressive generation"""
59
+
60
+ ...
61
+
62
+
63
+ @dataclass
64
+ class Ernie4_5_MoeModelOutputWithPast(ModelOutput):
65
+ """Class for Ernie4_5_Moe model outputs with past keys."""
66
+
67
+ last_hidden_state: Optional[torch.FloatTensor] = None
68
+ past_key_values: Optional[Cache] = None
69
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
70
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
71
+ router_loss: Optional[torch.FloatTensor] = None
72
+ gate_logits: Optional[tuple[torch.FloatTensor, ...]] = None
73
+
74
+
75
+ @dataclass
76
+ class Ernie4_5_MoeCausalLMOutputWithPast(MoeCausalLMOutputWithPast):
77
+ """Class for Ernie4_5_Moe causal LM output with past keys"""
78
+
79
+ router_loss: Optional[torch.FloatTensor] = None
80
+
81
+
82
+ def rotate_half(x):
83
+ """Rotates half the hidden dims of the input."""
84
+
85
+ x1 = x[..., 0::2]
86
+ x2 = x[..., 1::2]
87
+ return torch.stack((-x2, x1), dim=-1).reshape(x.shape)
88
+
89
+
90
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
91
+ """
92
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
93
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
94
+ """
95
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
96
+ if n_rep == 1:
97
+ return hidden_states
98
+ hidden_states = hidden_states[:, :, None, :, :].expand(
99
+ batch, num_key_value_heads, n_rep, slen, head_dim
100
+ )
101
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
102
+
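A quick equivalence check for `repeat_kv` (a sketch; the shapes are arbitrary):

```python
# repeat_kv should match torch.repeat_interleave along the head dimension.
import torch

kv = torch.randn(2, 4, 7, 16)  # (batch, num_key_value_heads, seq_len, head_dim)
out = repeat_kv(kv, n_rep=3)   # -> (2, 12, 7, 16)
ref = torch.repeat_interleave(kv, repeats=3, dim=1)
assert out.shape == (2, 12, 7, 16) and torch.equal(out, ref)
```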
103
+
104
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
105
+ """Applies Rotary Position Embedding to the query and key tensors.
106
+
107
+ Args:
108
+ q (`torch.Tensor`): The query tensor.
109
+ k (`torch.Tensor`): The key tensor.
110
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
111
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
112
+ position_ids (`torch.Tensor`, *optional*):
113
+ Deprecated and unused.
114
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
115
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
116
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
117
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
118
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
119
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
120
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
121
+ Returns:
122
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
123
+ """
124
+ orig_dtype = q.dtype
125
+ sin_pos = torch.stack([sin, sin], dim=-1).reshape(*sin.shape[:-1], -1).unsqueeze(unsqueeze_dim)
126
+ cos_pos = torch.stack([cos, cos], dim=-1).reshape(*cos.shape[:-1], -1).unsqueeze(unsqueeze_dim)
127
+ q_embed = (q.float() * cos_pos) + (rotate_half(q).float() * sin_pos)
128
+ k_embed = (k.float() * cos_pos) + (rotate_half(k).float() * sin_pos)
129
+ return q_embed.to(orig_dtype), k_embed.to(orig_dtype)
130
+
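A shape sanity check for the rotary application (a sketch): `cos`/`sin` carry half the head dim, are interleaved back to full width by the stack+reshape above, and broadcast over the head axis via `unsqueeze_dim`.

```python
import torch

B, H, S, D = 2, 4, 5, 8
q, k = torch.randn(B, H, S, D), torch.randn(B, H, S, D)
pos = torch.arange(S, dtype=torch.float32)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, D, 2, dtype=torch.float32) / D))  # (D/2,)
freqs = pos[:, None] * inv_freq[None, :]                                        # (S, D/2)
cos = torch.cos(freqs).expand(B, S, D // 2)
sin = torch.sin(freqs).expand(B, S, D // 2)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
assert q_rot.shape == q.shape and k_rot.shape == k.shape
```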
131
+
132
+ def eager_attention_forward(
133
+ module: nn.Module,
134
+ query: torch.Tensor,
135
+ key: torch.Tensor,
136
+ value: torch.Tensor,
137
+ attention_mask: Optional[torch.Tensor],
138
+ scaling: float,
139
+ dropout: float = 0.0,
140
+ **kwargs,
141
+ ):
142
+ """
143
+ Eager attention for Ernie4_5_Attention forward function.
144
+ """
145
+ key_states = repeat_kv(key, module.num_key_value_groups)
146
+ value_states = repeat_kv(value, module.num_key_value_groups)
147
+
148
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
149
+ if attention_mask is not None:
150
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
151
+ attn_weights = attn_weights + causal_mask.to(attn_weights.device)
152
+
153
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
154
+ query.dtype
155
+ )
156
+ attn_weights = nn.functional.dropout(
157
+ attn_weights, p=dropout, training=module.training
158
+ )
159
+ attn_output = torch.matmul(attn_weights, value_states)
160
+ attn_output = attn_output.transpose(1, 2).contiguous()
161
+
162
+ return attn_output, attn_weights
163
+
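A tiny numerical check (a sketch): with no mask and no dropout, the eager path agrees with PyTorch's SDPA. `module` only needs the two attributes the function reads.

```python
import types
import torch
import torch.nn.functional as F

module = types.SimpleNamespace(num_key_value_groups=1, training=False)
q, k, v = (torch.randn(1, 2, 5, 8) for _ in range(3))
out, _ = eager_attention_forward(module, q, k, v, attention_mask=None, scaling=8 ** -0.5)
ref = F.scaled_dot_product_attention(q, k, v)  # default scale is head_dim ** -0.5
assert torch.allclose(out, ref.transpose(1, 2), atol=1e-5)
```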
164
+
165
+ def topk_gate_func(
166
+ module: nn.Module,
167
+ hidden_states: torch.Tensor,
168
+ ):
169
+ """
170
+ Topk gate function for Ernie4_5_MoEMlp
171
+ """
172
+ capacity = module.get_capacity(hidden_states.shape[0])
173
+ with torch.autocast(device_type="cuda", dtype=torch.float32):
174
+ logits = module.gate(hidden_states.float())
175
+ router_loss = torch.zeros([1], dtype=torch.float32, device=hidden_states.device)
176
+ router_loss = router_loss.detach()  # detach returns a new tensor; reassign so the detached value is used
177
+ return logits, capacity, router_loss
178
+
179
+
180
+ class Ernie4_5_ResidualWithDropout(nn.Module):
181
+ """
182
+ Fused dropout implementation with residual connection support.
183
+
184
+ This layer combines dropout and residual addition in a single operation for better performance,
185
+ particularly on GPU devices. The dropout is conditionally applied based on the probability.
186
+
187
+ Args:
188
+ prob (float): Dropout probability (between 0 and 1)
189
+
190
+ Attributes:
191
+ prob (float): Stores the dropout probability
192
+ dropout (nn.Dropout): The actual dropout layer instance
193
+ """
194
+
195
+ def __init__(self, prob):
196
+ """
197
+ Initialize the fused dropout layer.
198
+
199
+ Args:
200
+ prob (float): Dropout probability (0 means no dropout)
201
+ """
202
+ super().__init__()
203
+ self.prob = prob
204
+ self.dropout = nn.Dropout(p=prob)
205
+
206
+ def forward(self, x, y):
207
+ """
208
+ Forward pass of the fused dropout layer.
209
+
210
+ Args:
211
+ x (torch.Tensor): Input tensor to potentially apply dropout on
212
+ y (torch.Tensor): Residual tensor to add to the (possibly dropped out) x
213
+
214
+ Returns:
215
+ torch.Tensor: Result of x (with optional dropout) + y
216
+ """
217
+ if self.prob > 0:
218
+ x = self.dropout(x)
219
+ output = x + y
220
+
221
+ return output
222
+
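Usage sketch: in eval mode the dropout is the identity, so the layer reduces to a plain residual add.

```python
import torch

res_add = Ernie4_5_ResidualWithDropout(prob=0.1).eval()
x, y = torch.randn(2, 3), torch.randn(2, 3)
assert torch.equal(res_add(x, y), x + y)
```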
223
+
224
+ class Ernie4_5_Attention(nn.Module):
225
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
226
+
227
+ def __init__(self, config, layer_idx=0):
228
+ """
229
+ Args:
230
+ config (Ernie4_5_MoeConfig): Model configuration.
231
+ layer_idx (int, optional): Index in transformer stack. Defaults to 0.
232
+ """
233
+ super().__init__()
234
+ self.layer_idx = layer_idx
235
+ self.hidden_size = config.hidden_size
236
+ self.num_heads = config.num_attention_heads
237
+ self.num_key_value_heads = (
238
+ config.num_key_value_heads
239
+ if config.num_key_value_heads is not None
240
+ else self.num_heads
241
+ )
242
+ self.num_key_value_groups = (
243
+ config.num_attention_heads // config.num_key_value_heads
244
+ )
245
+ self.head_dim = self.hidden_size // self.num_heads
246
+ self.freq_allocation = (
247
+ config.freq_allocation if hasattr(config, "freq_allocation") else 0
248
+ )
249
+ self.scaling = self.head_dim**-0.5
250
+ self.attention_dropout = getattr(config, "attention_probs_dropout_prob", 0.0)
251
+ self.is_causal = True
252
+
253
+ self.q_proj = nn.Linear(
254
+ self.hidden_size,
255
+ self.num_heads * self.head_dim,
256
+ bias=config.use_bias,
257
+ )
258
+
259
+ self.k_proj = nn.Linear(
260
+ self.hidden_size,
261
+ self.num_key_value_heads * self.head_dim,
262
+ bias=config.use_bias,
263
+ )
264
+
265
+ self.v_proj = nn.Linear(
266
+ self.hidden_size,
267
+ self.num_key_value_heads * self.head_dim,
268
+ bias=config.use_bias,
269
+ )
270
+
271
+ self.o_proj = nn.Linear(
272
+ self.hidden_size,
273
+ self.hidden_size,
274
+ bias=config.use_bias,
275
+ )
276
+
277
+ self.config = config
278
+
279
+ def forward(
280
+ self,
281
+ hidden_states: torch.Tensor,
282
+ attention_mask: Optional[torch.Tensor] = None,
283
+ past_key_value: Optional[Cache] = None,
284
+ position_ids: Optional[torch.Tensor] = None,
285
+ cache_position: Optional[torch.LongTensor] = None,
286
+ position_embeddings: tuple[torch.Tensor, torch.Tensor] = None,
287
+ **kwargs: Unpack[FlashAttentionKwargs],
288
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
293
+ """
294
+ Ernie4_5_Attention forward function
295
+ """
296
+ B, L = hidden_states.shape[:-1]
297
+
298
+ query_states = (
299
+ self.q_proj(hidden_states).view(B, L, self.num_heads, -1).transpose(1, 2)
300
+ )
301
+ key_states = (
302
+ self.k_proj(hidden_states)
303
+ .view(B, L, self.num_key_value_heads, -1)
304
+ .transpose(1, 2)
305
+ )
306
+ value_states = (
307
+ self.v_proj(hidden_states)
308
+ .view(B, L, self.num_key_value_heads, -1)
309
+ .transpose(1, 2)
310
+ )
311
+
312
+ cos, sin = position_embeddings
313
+ query_states, key_states = apply_rotary_pos_emb(
314
+ query_states, key_states, cos, sin
315
+ )
316
+
317
+ if past_key_value is not None:
318
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
319
+ cache_kwargs = {"cache_position": cache_position}
320
+ key_states, value_states = past_key_value.update(
321
+ key_states, value_states, self.layer_idx, cache_kwargs
322
+ )
323
+
324
+ attention_interface: Callable = eager_attention_forward
325
+ if self.config._attn_implementation != "eager":
326
+ attention_interface = ALL_ATTENTION_FUNCTIONS[
327
+ self.config._attn_implementation
328
+ ]
329
+
330
+ attn_output, attn_weights = attention_interface(
331
+ self,
332
+ query_states,
333
+ key_states,
334
+ value_states,
335
+ attention_mask,
336
+ dropout=0.0 if not self.training else self.attention_dropout,
337
+ scaling=self.scaling,
338
+ **kwargs,
339
+ )
340
+ attn_output = attn_output.reshape(B, L, -1).contiguous()
341
+ attn_output = self.o_proj(attn_output)
342
+
343
+ return attn_output, attn_weights
344
+
345
+
346
+ class Ernie4_5_MLP(nn.Module):
347
+ """
348
+ Ernie4_5_MLP - Gated Multi-Layer Perceptron module used in Ernie model.
349
+ """
350
+
351
+ def __init__(self, config, intermediate_size=None):
352
+ """
353
+ Initialize the MLP module with configuration options.
354
+
355
+ Args:
356
+ config: Model configuration object with attributes:
357
+ - hidden_size: int
358
+ - intermediate_size: int
359
+ - use_bias: bool
360
+ intermediate_size (int, optional): Overrides config.intermediate_size when given (used for expert FFNs).
361
+ """
362
+ super().__init__()
363
+ self.config = config
364
+ self.hidden_size = config.hidden_size
365
+ self.intermediate_size = (
366
+ intermediate_size
367
+ if intermediate_size is not None
368
+ else config.intermediate_size
369
+ )
370
+ self.gate_proj = nn.Linear(
371
+ self.hidden_size, self.intermediate_size, bias=config.use_bias
372
+ )
373
+ self.up_proj = nn.Linear(
374
+ self.hidden_size, self.intermediate_size, bias=config.use_bias
375
+ )
376
+ self.down_proj = nn.Linear(
377
+ self.intermediate_size, self.hidden_size, bias=config.use_bias
378
+ )
379
+
380
+ def forward(self, x):
381
+ """
382
+ Args:
383
+ x (Tensor): shape [batch_size, seq_len, hidden_size]
384
+
385
+ Returns:
386
+ Tensor: shape [batch_size, seq_len, hidden_size]
387
+ """
388
+ down_proj = self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
389
+ return down_proj
390
+
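A minimal usage sketch of the gated MLP (SwiGLU-style: `down(silu(gate(x)) * up(x))`); a `SimpleNamespace` stands in for the config here.

```python
import types
import torch

cfg = types.SimpleNamespace(hidden_size=16, intermediate_size=32, use_bias=False)
mlp = Ernie4_5_MLP(cfg)
x = torch.randn(2, 5, 16)
assert mlp(x).shape == (2, 5, 16)
```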
391
+
392
+ class Ernie4_5_MoeStatics(nn.Module):
393
+ """
394
+ Stores MoE (Mixture of Experts) statistics
395
+ and expert usage information.
396
+ """
397
+
398
+ def __init__(self, config):
399
+ """
400
+ Initialize MoE statistics tracking.
401
+
402
+ Args:
403
+ config: Model configuration containing MoE parameters
404
+ """
405
+ super().__init__()
406
+
407
+ num_experts = config.moe_num_experts
408
+ num_experts_groups = 1
409
+
410
+ self.e_score_correction_bias = nn.Parameter(
411
+ torch.zeros(num_experts_groups, num_experts, dtype=torch.float32),
412
+ requires_grad=False,
413
+ )
414
+
415
+
416
+ class Ernie4_5_MoeMLP(nn.Module):
417
+ """Mixture of Experts (MoE) variant of ERNIE's MLP layer."""
418
+
419
+ def __init__(self, config):
420
+ super().__init__()
421
+ self.config = config
422
+ self.k = config.moe_k
423
+ self.sinkhorn_2gate = config.sinkhorn_2gate
424
+ self.sinkhorn_temp = config.sinkhorn_temp
425
+
426
+ moe_intermediate_size = (
427
+ config.moe_intermediate_size
428
+ if config.moe_intermediate_size
429
+ else config.intermediate_size
430
+ )
431
+ self.gate = nn.Linear(
432
+ config.hidden_size, config.moe_num_experts, bias=False, dtype=torch.float32
433
+ )
434
+ if config.moe_gate_act == "softmax":
435
+ self.gate_act = partial(F.softmax, dim=-1)
436
+ elif config.moe_gate_act == "sigmoid":
437
+ self.gate_act = F.sigmoid
438
+ else:
439
+ raise ValueError(f"{config.moe_gate_act} is not supported.")
440
+
441
+ self.experts = nn.ModuleList(
442
+ [
443
+ Ernie4_5_MLP(config, moe_intermediate_size)
444
+ for i in range(config.moe_num_experts)
445
+ ]
446
+ )
447
+
448
+ if config.moe_use_aux_free:
449
+ self.moe_statics = Ernie4_5_MoeStatics(config)
450
+
451
+ self.use_correction_bias = config.moe_use_aux_free
452
+ self.num_local_experts = len(self.experts)
453
+
454
+ self.shared_experts = self._init_shared_experts()
455
+
456
+ def _init_shared_experts(self):
457
+ """
458
+ Initialize the shared expert module.
459
+
460
+ Returns:
461
+ shared_experts: Shared expert module, returns None if no shared experts are needed.
462
+
463
+ """
464
+ cfg = deepcopy(self.config)
465
+ if getattr(cfg, "moe_num_shared_experts", 0) > 0:
466
+ if getattr(cfg, "moe_intermediate_size", None):
467
+ cfg.intermediate_size = (
468
+ cfg.moe_intermediate_size * cfg.moe_num_shared_experts
469
+ )
470
+ else:
471
+ cfg.intermediate_size = (
472
+ cfg.intermediate_size * cfg.moe_num_shared_experts
473
+ )
474
+ shared_experts = Ernie4_5_MLP(cfg, cfg.intermediate_size)
475
+ else:
476
+ shared_experts = None
477
+ return shared_experts
478
+
479
+ def forward(
480
+ self,
481
+ input: torch.Tensor,
482
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
483
+ """
484
+ Forward pass through MoE layer.
485
+
486
+ Args:
487
+ input (Tensor): Input tensor of shape [s, d].
489
+
490
+ Returns:
491
+ tuple: (output, combine_weights, router_loss, gate_logits)
492
+ """
493
+
494
+ if input.dim() == 3:
495
+ orig_shape = input.shape
496
+ input = input.reshape(-1, input.shape[-1])
497
+ else:
498
+ orig_shape = None
499
+ assert (
500
+ input.dim() == 2
501
+ ), f"input Tensor must have dimensions: (s)equence, (d)im, got:{input.shape}"
502
+
503
+ assert self.gate is not None
504
+
505
+ gate_input = input
506
+
507
+ (
508
+ dispatched_input,
509
+ combine_weights,
510
+ dispatch_mask,
511
+ scatter_index,
512
+ router_loss,
513
+ gate_logits,
514
+ gate_prob,
515
+ ) = self.gate_and_dispatch(gate_input)
516
+
517
+ expert_out = self.forward_experts(dispatched_input)
518
+
519
+ combined_output = self.combine_expert_output(
520
+ expert_out, combine_weights, scatter_index
521
+ )
522
+
523
+ if self.shared_experts is not None:
524
+ shared_expert_out = self.shared_experts(gate_input)
525
+ combined_output += shared_expert_out
526
+
527
+ if orig_shape:
528
+ combined_output = combined_output.reshape(
529
+ orig_shape[:-1] + (combined_output.shape[-1],)
530
+ )
531
+
532
+ return combined_output, combine_weights, router_loss, gate_logits
533
+
534
+ def forward_experts(self, dispatched_input: torch.Tensor) -> torch.Tensor:
535
+ """
536
+ Forward pass through experts sequentially.
537
+
538
+ Args:
539
+ dispatched_input (Tensor): Input tensor of shape [num_experts, capacity, dim].
540
+
541
+ Returns:
542
+ Tensor: Expert outputs of shape [num_experts, capacity, dim].
543
+ """
544
+ true_experts = self.experts
545
+ dispatched_input = dispatched_input.reshape(
546
+ 1, self.num_local_experts, -1, dispatched_input.shape[-1]
547
+ )
548
+ expert_outputs = []
549
+ if isinstance(self.experts, nn.ModuleList):
550
+ chunks = dispatched_input.permute(1, 0, 2, 3).contiguous().unbind(0)
551
+ assert len(chunks) == len(
552
+ true_experts
553
+ ), f"{len(chunks)}, {len(true_experts)}"
554
+ for chunk, expert in zip(chunks, true_experts):
555
+ expert_outputs.append(expert(chunk))
556
+ else:
557
+ dispatched_input = dispatched_input.permute(1, 0, 2, 3).contiguous()
558
+ orig_shape = dispatched_input.shape
559
+ chunks = dispatched_input.reshape(orig_shape[0], -1, orig_shape[-1])
560
+ chunks = self.experts(chunks)
561
+ chunks = chunks.reshape(orig_shape[:-1] + (chunks.shape[-1],)).unbind(0)
562
+ expert_outputs.extend(chunks)
563
+
564
+ expert_output = torch.stack(expert_outputs, dim=1)
565
+ return expert_output
566
+
567
+ def moe_gate_dispatch(
568
+ self,
569
+ x: torch.Tensor,
570
+ gate_logits: torch.Tensor,
571
+ k: int,
572
+ capacity: Optional[int],
573
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
574
+ """
575
+ Dispatch inputs to experts based on their routing probabilities.
576
+ """
577
+ S, H = x.shape
578
+ E = gate_logits.shape[1]
579
+ device = x.device
580
+ topk_prob, topk_idx = torch.topk(gate_logits, k, dim=-1)
581
+ combine_weights = topk_prob
582
+ expert_id = topk_idx
583
+ y = x.new_zeros((E, capacity, H))
584
+ scatter_index = x.new_full((k, S), -1, dtype=torch.int32)
585
+
586
+ # per-expert slot counters
587
+ slot_counter = torch.zeros(E, dtype=torch.int32, device=device)
588
+
589
+ for tok in range(S):
590
+ for route in range(k):
591
+ e = expert_id[tok, route].item()
592
+ slot = slot_counter[e].item()
593
+ if slot >= capacity:
594
+ combine_weights[tok, route] = 0.0
595
+ continue
596
+
597
+ # record mapping & dispatch activation
598
+ scatter_index[route, tok] = e * capacity + slot
599
+ y[e, slot] = x[tok]
600
+ slot_counter[e] += 1
601
+
602
+ expert_offset = torch.cumsum(slot_counter, 0, dtype=torch.int64)
603
+
604
+ return y, combine_weights, scatter_index, expert_offset, expert_id
605
+
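A worked toy run of the capacity-bounded dispatch above (a sketch; `self` is unused by `moe_gate_dispatch`, so it can be invoked through the class for a quick look). With 4 tokens, 2 experts, k=1 and capacity=2, the third token routed to the already-full expert 0 overflows and its combine weight is zeroed.

```python
import torch

x = torch.arange(8, dtype=torch.float32).reshape(4, 2)  # (S=4, H=2)
probs = torch.tensor([[0.9, 0.1], [0.8, 0.2], [0.7, 0.3], [0.2, 0.8]])
y, w, scatter, offsets, expert_id = Ernie4_5_MoeMLP.moe_gate_dispatch(
    None, x, probs, k=1, capacity=2
)
print(w.squeeze(-1))  # tensor([0.9000, 0.8000, 0.0000, 0.8000]) -- token 2 overflowed
print(scatter)        # tensor([[ 0,  1, -1,  2]], dtype=torch.int32)
```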
606
+ def combine_expert_output(
607
+ self,
608
+ expert_output: torch.Tensor,
609
+ combine_weights: torch.Tensor,
610
+ scatter_index: torch.Tensor,
611
+ ) -> torch.Tensor:
612
+ """
613
+ Combine expert outputs using combination weights.
614
+
615
+ Args:
616
+ expert_output (Tensor): Expert outputs [num_experts, capacity, dim].
617
+ combine_weights (Tensor): Combination weights.
618
+ scatter_index (Tensor): Scatter indices.
619
+
620
+ Returns:
621
+ Tensor: Combined output [seqlen, dim].
622
+ """
623
+ expert_output = expert_output.reshape(-1, expert_output.shape[-1])
624
+ combined_output = self.combining(expert_output, combine_weights, scatter_index)
625
+ return combined_output
626
+
627
+ def combining(self, x, combine_weights, scatter_index):
628
+ """
629
+ Combines and aggregates input matrix using combination weights.
630
+
631
+ Args:
632
+ x (Tensor): Input tensor of shape [num_experts * capacity, dim]
633
+ combine_weights (Tensor): Combination weights of shape [seq, 2]
634
+ scatter_index (Tensor): Scatter indices of shape [seq, 2]
635
+
636
+ Returns:
637
+ Tensor: Combined output tensor of shape [seq, dim]
638
+ """
639
+ dim = x.shape[-1]
640
+
641
+ scatter_index = scatter_index.reshape([-1])
642
+ num_k = combine_weights.shape[-1]
643
+
644
+ combine_weights = combine_weights.unsqueeze(1)
645
+
646
+ x = x[scatter_index].reshape([-1, num_k, dim])
647
+
648
+ return torch.matmul(combine_weights, x).squeeze(1)
649
+
650
+ def gate_and_dispatch(self, input):
651
+ """
652
+ Calculate gate and dispatch inputs.
653
+
654
+ Args:
655
+ input: Input tensor of shape [seq, dim]
656
+
657
+ Returns:
658
+ tuple: (dispatched_input, combine_weights, dispatch_mask,
659
+ scatter_index, router_loss, gate_logits, gate_prob)
660
+ """
661
+ gate_logits, capacity, router_loss = topk_gate_func(self, input)
662
+
663
+ # apply the gate activation; `capacity` is forwarded to moe_gate_dispatch below
664
+ prob = self.gate_act(gate_logits)
665
+ (
666
+ dispatched_input,
667
+ combine_weights_unnorm,
668
+ scatter_index,
669
+ dispatch_mask,
670
+ _,
671
+ ) = self.moe_gate_dispatch(input, prob, k=self.k, capacity=capacity)
672
+ dispatch_mask = torch.diff(F.pad(dispatch_mask, (1, 0)))
673
+
674
+ scatter_index = scatter_index.detach()
676
+ dispatch_mask = dispatch_mask.detach()
676
+
677
+ scatter_index = scatter_index.transpose(0, 1) # [k, s] -> [s, k]
678
+ combine_weights = combine_weights_unnorm / torch.clamp(
679
+ combine_weights_unnorm.sum(dim=-1, keepdim=True), min=1e-12
680
+ )
681
+ combine_weights = combine_weights.to(dtype=dispatched_input.dtype)
682
+
683
+ return (
684
+ dispatched_input,
685
+ combine_weights,
686
+ dispatch_mask,
687
+ scatter_index,
688
+ router_loss,
689
+ gate_logits,
690
+ prob,
691
+ )
692
+
693
+ def get_capacity(self, num_tokens, cap_factor=None):
694
+ """
695
+ Calculate capacity based on number of tokens.
696
+
697
+ Args:
698
+ num_tokens: Number of input tokens
699
+ cap_factor: Optional capacity factor override
700
+
701
+ Returns:
702
+ int: Calculated capacity
703
+ """
704
+ num_experts = self.config.moe_num_experts
705
+ if cap_factor is not None:
706
+ cap = cap_factor
707
+ else:
708
+ if self.training:
709
+ cap = self.config.moe_capacity[0]
710
+ elif num_tokens < num_experts:
711
+ cap = self.config.moe_capacity[2]
712
+ else:
713
+ cap = self.config.moe_capacity[1]
714
+
715
+ capacity = int(cap * num_tokens // num_experts)
716
+ assert (
717
+ capacity > 0
718
+ ), f"requires capacity to >= 0. cap={cap}, num_tokens={num_tokens}"
719
+ return capacity
720
+
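Worked numbers for `get_capacity` (a sketch): with `moe_capacity=(64, 64, 64)`, 64 experts and a 1024-token eval batch, `cap` resolves to 64 and each expert receives `cap * num_tokens // num_experts` slots.

```python
cap, num_tokens, num_experts = 64, 1024, 64
capacity = int(cap * num_tokens // num_experts)
assert capacity == 1024  # 64 * 1024 // 64
```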
721
+
722
+ class Ernie4_5_RMSNorm(nn.Module):
723
+ """
724
+ Ernie Root Mean Square Layer Normalization (Ernie4_5_RMSNorm) implementation.
725
+
726
+ Ernie4_5_RMSNorm is a simplified version of LayerNorm that focuses on the root mean square of inputs,
727
+ omitting the mean-centering operation. This provides computational efficiency while maintaining
728
+ good performance.
729
+
730
+ """
731
+
732
+ def __init__(self, config):
733
+ """
734
+ Initialize RMSNorm layer.
735
+
736
+ Args:
737
+ config (Ernie4_5_MoeConfig): Model configuration.
738
+ """
739
+ super().__init__()
740
+ self.config = config
741
+ self.hidden_size = config.hidden_size
742
+ self.weight = nn.Parameter(torch.ones(config.hidden_size))
743
+ self.variance_epsilon = config.rms_norm_eps
744
+
745
+ def forward(self, hidden_states):
746
+ """
747
+ Apply RMS normalization to input hidden states.
748
+
749
+ Args:
750
+ hidden_states (Tensor): Input tensor of shape [batch_size, seq_len, hidden_size]
751
+
752
+ Returns:
753
+ Tensor: Normalized output tensor of same shape as input
754
+ """
755
+ input_dtype = hidden_states.dtype
756
+ hidden_states = hidden_states.to(torch.float32)
757
+ variance = hidden_states.pow(2).mean(dim=-1, keepdim=True)
758
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
759
+
760
+ return self.weight * hidden_states.to(input_dtype)
761
+
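A numeric check (sketch): unlike LayerNorm there is no mean subtraction; the input is only divided by its root-mean-square over the last dim, then scaled by the learned weight.

```python
import types
import torch

cfg = types.SimpleNamespace(hidden_size=4, rms_norm_eps=1e-6)
norm = Ernie4_5_RMSNorm(cfg)
x = torch.randn(2, 3, 4)
ref = x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
assert torch.allclose(norm(x), ref, atol=1e-6)
```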
762
+
763
+ class Ernie4_5_RopeEmbedding(nn.Module):
764
+ """
765
+ Implements Rotary Position Embedding (RoPE) for Ernie4_5_MoeModel.
766
+ """
767
+
768
+ def __init__(self, config: Ernie4_5_MoeConfig, device=None):
769
+ super().__init__()
770
+ # BC: "rope_type" was originally "type"
771
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
772
+ self.rope_type = config.rope_scaling.get(
773
+ "rope_type", config.rope_scaling.get("type")
774
+ )
775
+ else:
776
+ self.rope_type = "default"
777
+ self.max_seq_len_cached = config.max_position_embeddings
778
+ self.original_max_seq_len = config.max_position_embeddings
779
+
780
+ self.config = config
781
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
782
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
783
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
784
+ self.original_inv_freq = self.inv_freq
785
+
786
+ @torch.no_grad()
787
+ def forward(self, x, position_ids):
788
+ inv_freq_expanded = self.inv_freq[None, None, :].float()
789
+ position_ids_expanded = position_ids[..., None].float()
790
+ freqs = inv_freq_expanded.float() * position_ids_expanded.float()
791
+ cos = torch.cos(freqs) * self.attention_scaling
792
+ sin = torch.sin(freqs) * self.attention_scaling
793
+ return cos, sin
794
+ # cos/sin are intentionally kept in float32; apply_rotary_pos_emb casts back to the query dtype
795
+
796
+
797
+ class Ernie4_5_DecoderLayer(nn.Module):
798
+ """A single transformer decoder layer in ERNIE-MoE model.
799
+
800
+ Contains self-attention and feed-forward components with optional MoE (Mixture of Experts)
801
+ support, residual connections, and layer normalization.
802
+ """
803
+
804
+ def __init__(self, config, layer_idx):
805
+ """Initialize the decoder layer.
806
+
807
+ Args:
808
+ config (Ernie4_5_MoeConfig): Model configuration.
809
+ layer_idx (int): Index of this layer in the transformer stack
810
+ """
811
+ super().__init__()
812
+ self.hidden_size = config.hidden_size
813
+ self.layer_idx = layer_idx
814
+ self.config = config
815
+ self.use_moe = config.use_moe
816
+ self.self_attn = Ernie4_5_Attention(config, layer_idx)
817
+
818
+ moe_layer_start_index = (
819
+ min(config.moe_layer_start_index)
820
+ if isinstance(config.moe_layer_start_index, (tuple, list))
821
+ else config.moe_layer_start_index
822
+ )
823
+ moe_layer_end_index = (
824
+ max(config.moe_layer_end_index)
825
+ if isinstance(config.moe_layer_end_index, (tuple, list))
826
+ else config.moe_layer_end_index
827
+ )
828
+
829
+ if (
830
+ self.use_moe
831
+ and ((layer_idx + 1) % config.moe_layer_interval == 0)
832
+ and layer_idx >= moe_layer_start_index
833
+ and layer_idx <= moe_layer_end_index
834
+ ):
835
+ self.mlp = Ernie4_5_MoeMLP(config)
836
+ else:
837
+ self.mlp = Ernie4_5_MLP(config)
838
+
839
+ self.input_layernorm = Ernie4_5_RMSNorm(config)
840
+ self.post_attention_layernorm = Ernie4_5_RMSNorm(config)
841
+
842
+ self.residual_add1 = Ernie4_5_ResidualWithDropout(config.hidden_dropout_prob)
843
+ self.residual_add2 = Ernie4_5_ResidualWithDropout(config.hidden_dropout_prob)
844
+
845
+ def forward(
846
+ self,
847
+ hidden_states: torch.Tensor,
848
+ attention_mask: Optional[torch.Tensor] = None,
849
+ position_ids: Optional[torch.Tensor] = None,
850
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
851
+ output_attentions: Optional[bool] = False,
852
+ use_cache: Optional[bool] = False,
853
+ cache_position: Optional[torch.LongTensor] = None,
854
+ position_embeddings: Optional[
855
+ tuple[torch.Tensor, torch.Tensor]
856
+ ] = None, # necessary, but kept here for BC
857
+ output_router_loss: bool = True,
858
+ output_gate_logits: bool = True,
859
+ **kwargs: Unpack[FlashAttentionKwargs],
860
+ ) -> tuple[
861
+ torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]
862
+ ]:
863
+ """Forward pass through the decoder layer.
864
+
865
+ Args:
866
+ hidden_states (torch.Tensor): Input tensor [batch_size, seq_len, hidden_size]
867
+ attention_mask (Optional[torch.Tensor]): Attention mask tensor
868
+ position_ids (Optional[torch.Tensor]): Position indices for rotary embeddings
869
+ past_key_value (Optional[Tuple[torch.Tensor]]): Cached key/value states
870
+ output_attentions (Optional[bool]): Whether to return attention weights
871
+ use_cache (Optional[bool]): Whether to cache key/value states
872
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
873
+ Indices depicting the position of the input sequence tokens in the sequence.
874
+ position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
875
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
876
+ with `head_dim` being the embedding dimension of each attention head.
877
+ output_router_loss (bool): Whether to return MoE router loss
878
+ output_gate_logits (bool): Whether to return MoE gate logits
879
+
880
+ Returns:
881
+ Union: Various output combinations depending on arguments:
882
+ - Base case: Hidden states tensor
883
+ - With attention: Tuple of (hidden_states, attention_weights)
884
+ - With router loss: May include gate logits in output tuple
885
+ - With MoE gate logits: May include gate logits in output tuple
886
+ """
887
+ residual = hidden_states
888
+
889
+ hidden_states = self.input_layernorm(hidden_states)
890
+
891
+ # Self Attention
892
+ hidden_states, self_attn_weights = self.self_attn(
893
+ hidden_states=hidden_states,
894
+ attention_mask=attention_mask,
895
+ past_key_value=past_key_value,
896
+ position_ids=position_ids,
897
+ use_cache=use_cache,
898
+ cache_position=cache_position,
899
+ position_embeddings=position_embeddings,
900
+ **kwargs,
901
+ )
902
+
903
+ hidden_states = self.residual_add1(hidden_states, residual)
904
+
905
+ # Fully Connected
906
+ residual = hidden_states
907
+ hidden_states = self.post_attention_layernorm(hidden_states)
908
+
909
+ router_loss = None
910
+ gate_logits = None
911
+
912
+ if isinstance(self.mlp, Ernie4_5_MoeMLP):
913
+ hidden_states, _, router_loss, gate_logits = self.mlp(hidden_states)
914
+ else:
915
+ hidden_states = self.mlp(hidden_states)
916
+
917
+ hidden_states = self.residual_add2(hidden_states, residual)
918
+
919
+ outputs = (hidden_states,)
920
+
921
+ if output_attentions:
922
+ outputs += (self_attn_weights,)
923
+
924
+ if output_router_loss:
925
+ outputs += (router_loss,)
926
+
927
+ if output_gate_logits:
928
+ outputs += (gate_logits,)
929
+
930
+ return outputs
931
+
932
+
933
+ @auto_docstring
934
+ class Ernie4_5_PretrainedModel(PreTrainedModel):
935
+ """Base class for ERNIE pretrained models."""
936
+
937
+ config_class = Ernie4_5_MoeConfig
938
+ base_model_prefix = "model"
939
+ supports_gradient_checkpointing = True
940
+ _no_split_modules = ["Ernie4_5_DecoderLayer"]
941
+ _skip_keys_device_placement = ["past_key_values"]
942
+ _supports_flash_attn_2 = True
943
+ _supports_sdpa = True
944
+ _supports_flex_attn = True
945
+ _supports_cache_class = True
946
+ _supports_quantized_cache = True
947
+ _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
948
+
949
+
950
+ def subbatch(f, arg_idx, axis, bs, out_idx, same_arg_idx={}):
951
+ """
952
+ Converts a function to one that applies to subbatch of an input dimension.
953
+ Useful for processing large tensors in smaller chunks to reduce memory usage.
954
+
955
+ Args:
956
+ f (Callable): Function to be subbatched.
957
+ arg_idx ([int]): Indices of the inputs to be subbatched.
958
+ axis ([int]): Indices of the dimensions to be subbatched for each input.
959
+ bs (int): Subbatch size.
960
+ out_idx (int): Dimension to concatenate outputs along.
961
+ same_arg_idx (dict): Mapping of argument indices that share the same tensor.
962
+
963
+ Returns:
964
+ Callable: New function that processes inputs in subbatches.
965
+ """
966
+
967
+ @wraps(f)
968
+ def wrapper(*args, **kwargs):
969
+
970
+ assert len(arg_idx) == len(
971
+ axis
972
+ ), "Number of batching args and number of batching dims should match."
973
+
974
+ inps = [args[i] for i in arg_idx]
975
+ axis_width = [inp.shape[d] for inp, d in zip(inps, axis)]
976
+ assert len(set(axis_width)) == 1, "Batch sizes should be kept equal."
977
+
978
+ inp_axis = {idx: d for idx, d in zip(arg_idx, axis)}
979
+
980
+ axis_width = axis_width[0]
981
+ if axis_width < bs:
982
+ return f(*args, **kwargs)
983
+
984
+ outs = []
985
+ for slice_at in range(0, axis_width, bs):
986
+ _args = []
987
+ for i, inp in enumerate(args):
988
+ if i in same_arg_idx:
989
+ assert (
990
+ i > same_arg_idx[i]
991
+ ), f"expect i > same_arg_idx[i], but got i: {i} and same_arg_idx[i]: {same_arg_idx[i]}"
992
+ _args.append(_args[same_arg_idx[i]])
993
+ elif i in arg_idx:
994
+ d = inp_axis[i]
995
+ start = slice_at
996
+ end = min(inp.shape[d], slice_at + bs)
997
+ # Build slice for all dims, only slice along axis d
998
+ slices = [slice(None)] * inp.ndim
999
+ slices[d] = slice(start, end)
1000
+ _args.append(inp[tuple(slices)])
1001
+ else:
1002
+ _args.append(inp)
1003
+
1004
+ out = f(*_args, **kwargs)
1005
+ outs.append(out)
1006
+
1007
+ return torch.cat(outs, dim=out_idx)
1008
+
1009
+ return wrapper
1010
+
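A usage sketch for `subbatch`: wrap a row-wise softmax so a 10-row input is processed in slices of 4 along dim 0 and concatenated back.

```python
import torch

def row_softmax(t):
    return torch.softmax(t, dim=-1)

chunked = subbatch(row_softmax, arg_idx=[0], axis=[0], bs=4, out_idx=0)
x = torch.randn(10, 6)
assert torch.allclose(chunked(x), row_softmax(x), atol=1e-6)
```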
1011
+
1012
+ class ErniePretrainingCriterion(nn.Module):
1013
+ """Criterion for ERNIE pretraining task."""
1014
+
1015
+ def __init__(self, config, return_tuple=True):
1016
+ """Initialize the pretraining criterion.
1017
+
1018
+ Args:
1019
+ config (Ernie4_5_MoeConfig): Model configuration.
1020
+ return_tuple (bool): Whether to return loss as tuple (loss, loss_sum). Defaults to True.
1021
+ """
1022
+ super().__init__()
1023
+ self.ignored_index = getattr(config, "ignored_index", -100)
1024
+ self.config = config
1025
+ self.return_tuple = return_tuple
1026
+
1027
+ self.loss_func = nn.CrossEntropyLoss(reduction="none")
1028
+
1029
+ def forward(self, prediction_scores, masked_lm_labels, loss_mask, router_loss=None):
1030
+ """Compute the combined pretraining loss.
1031
+
1032
+ Args:
1033
+ prediction_scores: Prediction scores tensor, [batch_size, seq_len, vocab_size]
1034
+ masked_lm_labels: Target labels tensor [batch_size, seq_len]
1035
+ loss_mask: Optional mask for valid tokens
1036
+ router_loss: Optional MoE router loss tensor
1037
+
1038
+ Returns:
1039
+ Union:
1040
+ - If return_tuple=True: Tuple of (combined_loss, mlm_loss_sum)
1041
+ - If return_tuple=False: Combined loss tensor
1042
+ """
1043
+ res = self.forward_impl(prediction_scores, masked_lm_labels, loss_mask)
1044
+
1045
+ if self.return_tuple:
1046
+ loss, loss_sum = res
1047
+ else:
1048
+ loss, loss_sum = res, None
1049
+
1050
+ if router_loss is not None and isinstance(router_loss, torch.Tensor):
1051
+ loss = loss + router_loss - router_loss.detach()
1052
+
1053
+ return loss, loss_sum
1054
+
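The `loss + router_loss - router_loss.detach()` term above is a straight-through trick; a sketch showing that the scalar value of the loss is unchanged while gradients still reach the router:

```python
import torch

loss = torch.tensor(2.0, requires_grad=True)
router_loss = torch.tensor(0.5, requires_grad=True)
combined = loss + router_loss - router_loss.detach()
assert combined.item() == 2.0            # value: the router term cancels out
combined.backward()
assert router_loss.grad.item() == 1.0    # gradient: the router term survives
```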
1055
+ def loss_impl(
1056
+ self, prediction_scores: torch.Tensor, masked_lm_labels: torch.Tensor
1057
+ ) -> torch.Tensor:
1058
+ """
1059
+ Core loss computation without reduction (but per-token).
1060
+
1061
+ Args:
1062
+ prediction_scores (torch.Tensor): Logits tensor [batch_size, seq_len, vocab_size].
1063
+ masked_lm_labels (torch.Tensor): Target labels tensor [batch_size, seq_len].
1064
+
1065
+ Returns:
1066
+ torch.Tensor: Unreduced loss tensor of shape [batch_size, seq_len].
1067
+ Losses are calculated in float32.
1068
+ """
1069
+ scores_float32 = prediction_scores.to(torch.float32)
1070
+ # prediction_scores: [batch_size, seq_len, vocab_size]
1071
+ # masked_lm_labels: [batch_size, seq_len]
1072
+ # Transpose prediction_scores to [batch_size, vocab_size, seq_len]
1073
+ unreduced_loss = self.loss_func(
1074
+ scores_float32.transpose(1, 2), # Shape: [batch_size, vocab_size, seq_len]
1075
+ masked_lm_labels.long(), # Shape: [batch_size, seq_len], ensure long type
1076
+ )
1077
+ # unreduced_loss will be of shape [batch_size, seq_len] and dtype float32
1078
+ return unreduced_loss
1079
+
1080
+ def forward_impl(self, prediction_scores, masked_lm_labels, loss_mask=None):
1081
+ """
1082
+ Loss function forward pass implementation.
1083
+ """
1084
+ prediction_scores_dims = len(prediction_scores.shape)
1085
+
1086
+ loss_subbatch_seqlen_config_key = "loss_subbatch_seqlen"
1087
+ default_loss_subbatch_seqlen = 32768
1088
+
1089
+ current_loss_subbatch_seqlen = getattr(self.config,
1090
+ loss_subbatch_seqlen_config_key, default_loss_subbatch_seqlen
1091
+ )
1092
+
1093
+ if (
1094
+ prediction_scores_dims == 2
1095
+ and prediction_scores.shape[0] > current_loss_subbatch_seqlen
1096
+ ):
1097
+ sb_loss_func = subbatch(
1098
+ self.loss_impl, [0, 1], [0, 0], current_loss_subbatch_seqlen, 0
1099
+ )
1100
+ masked_lm_loss = sb_loss_func(prediction_scores, masked_lm_labels)
1101
+ elif (
1102
+ prediction_scores_dims == 3
1103
+ and prediction_scores.shape[1] > current_loss_subbatch_seqlen
1104
+ ):
1105
+ sb_loss_func = subbatch(
1106
+ self.loss_impl, [0, 1], [1, 1], current_loss_subbatch_seqlen, 1
1107
+ )
1108
+ masked_lm_loss = sb_loss_func(prediction_scores, masked_lm_labels)
1109
+ else:
1110
+ masked_lm_loss = self.loss_impl(prediction_scores, masked_lm_labels)
1111
+
1112
+ if loss_mask is None:
1113
+ loss_mask = masked_lm_labels != self.ignored_index
1114
+
1115
+ loss_mask = loss_mask.reshape(-1).to(torch.float32)
1116
+
1117
+ masked_lm_loss = torch.sum(
1118
+ masked_lm_loss.to(torch.float32).reshape(-1) * loss_mask
1119
+ )
1120
+
1121
+ # The division will be in float32
1122
+ loss = masked_lm_loss / loss_mask.sum()
1123
+
1124
+ loss_sum = masked_lm_loss.sum().detach()
1125
+
1126
+ if not self.return_tuple:
1127
+ if self.training:
1128
+ return loss
1129
+ return loss_sum
1130
+ return loss, loss_sum
1131
+
1132
+
1133
+ @auto_docstring
1134
+ class Ernie4_5_Model(Ernie4_5_PretrainedModel):
1135
+ """The core ERNIE transformer model with MoE (Mixture of Experts) support."""
1136
+
1137
+ _keep_in_fp32_modules = ["gate"]
1138
+
1139
+ def __init__(self, config: Ernie4_5_MoeConfig):
1140
+ """Initialize the ERNIE model architecture."""
1141
+ super().__init__(config)
1142
+ self.padding_idx = config.pad_token_id
1143
+ self.vocab_size = config.vocab_size
1144
+ self.hidden_size = config.hidden_size
1145
+ self.config = config
1146
+
1147
+ self.embed_tokens = nn.Embedding(
1148
+ self.vocab_size,
1149
+ self.hidden_size,
1150
+ )
1151
+
1152
+ self.layers = nn.ModuleList(
1153
+ [Ernie4_5_DecoderLayer(config, i) for i in range(config.num_hidden_layers)]
1154
+ )
1155
+ self.norm = Ernie4_5_RMSNorm(config)
1156
+ self.rotary_emb = Ernie4_5_RopeEmbedding(config=config)
1157
+
1158
+ self.gradient_checkpointing = False
1159
+
1160
+ self.post_init()
1161
+
1162
+ def get_input_embeddings(self):
1163
+ """Get the input embedding layer."""
1164
+ return self.embed_tokens
1165
+
1166
+ def set_input_embeddings(self, value):
1167
+ """Set new input embeddings."""
1168
+ self.embed_tokens = value
1169
+
1170
+ def forward(
1171
+ self,
1172
+ input_ids: Optional[torch.LongTensor] = None,
1173
+ attention_mask: Optional[torch.Tensor] = None,
1174
+ position_ids: Optional[torch.LongTensor] = None,
1175
+ past_key_values: Optional[Cache] = None,
1176
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1177
+ use_cache: Optional[bool] = None,
1178
+ output_attentions: Optional[bool] = None,
1179
+ output_hidden_states: Optional[bool] = None,
1180
+ cache_position: Optional[torch.LongTensor] = None,
1181
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
1182
+ ):
1183
+ """Forward pass through the ERNIE model."""
1184
+ output_attentions = (
1185
+ output_attentions
1186
+ if output_attentions is not None
1187
+ else self.config.output_attentions
1188
+ )
1189
+ output_hidden_states = (
1190
+ output_hidden_states
1191
+ if output_hidden_states is not None
1192
+ else self.config.output_hidden_states
1193
+ )
1194
+
1195
+ if (input_ids is None) ^ (inputs_embeds is not None):
1196
+ raise ValueError(
1197
+ "You must specify exactly one of input_ids or inputs_embeds"
1198
+ )
1199
+
1200
+ if self.gradient_checkpointing and self.training:
1201
+ if use_cache:
1202
+ logger.warning_once(
1203
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1204
+ )
1205
+ use_cache = False
1206
+
1207
+ if use_cache and past_key_values is None:
1208
+ past_key_values = DynamicCache()
1209
+
1210
+ if inputs_embeds is None:
1211
+ inputs_embeds = self.embed_tokens(input_ids)
1212
+
1213
+ inputs_embeds = inputs_embeds.to(self.embed_tokens.weight.dtype)
1214
+
1215
+ if cache_position is None:
1216
+ past_seen_tokens = (
1217
+ past_key_values.get_seq_length() if past_key_values is not None else 0
1218
+ )
1219
+ cache_position = torch.arange(
1220
+ past_seen_tokens,
1221
+ past_seen_tokens + inputs_embeds.shape[1],
1222
+ device=inputs_embeds.device,
1223
+ )
1224
+ if position_ids is None:
1225
+ position_ids = cache_position.unsqueeze(0)
1226
+
1227
+ causal_mask = self._update_causal_mask(
1228
+ attention_mask,
1229
+ inputs_embeds,
1230
+ cache_position,
1231
+ past_key_values,
1232
+ output_attentions,
1233
+ )
1234
+
1235
+ hidden_states = inputs_embeds
1236
+
1237
+ # create position embeddings to be shared across the decoder layers
1238
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
1239
+
1240
+ # decoder layers
1241
+ all_hidden_states = () if output_hidden_states else None
1242
+ all_self_attns = () if output_attentions else None
1243
+ all_router_loss = (
1244
+ torch.tensor(0.0, device=inputs_embeds.device)
1245
+ if self.config.use_moe
1246
+ else None
1247
+ )
1248
+ all_gate_logits = ()
1249
+
1250
+ for decoder_layer in self.layers:
1251
+ if output_hidden_states:
1252
+ all_hidden_states += (hidden_states,)
1253
+
1254
+ if self.gradient_checkpointing and self.training:
1255
+ layer_outputs = self._gradient_checkpointing_func(
1256
+ partial(decoder_layer.__call__, **flash_attn_kwargs),
1257
+ hidden_states,
1258
+ causal_mask,
1259
+ position_ids,
1260
+ past_key_values,
1261
+ output_attentions,
1262
+ use_cache,
1263
+ cache_position,
1264
+ position_embeddings,
1265
+ )
1266
+ else:
1267
+ layer_outputs = decoder_layer(
1268
+ hidden_states,
1269
+ causal_mask,
1270
+ position_ids,
1271
+ past_key_values,
1272
+ output_attentions,
1273
+ use_cache,
1274
+ cache_position,
1275
+ position_embeddings,
1276
+ **flash_attn_kwargs,
1277
+ )
1278
+
1279
+ hidden_states = layer_outputs[0]
1280
+
1281
+ if output_attentions:
1282
+ all_self_attns += (layer_outputs[1],)
1283
+
1284
+ if self.config.use_moe:
1285
+ layer_outputs, gate_logits = layer_outputs[:-1], layer_outputs[-1]
1286
+ all_gate_logits = all_gate_logits + (gate_logits,)
1287
+
1288
+ hidden_states = self.norm(hidden_states)
1289
+
1290
+ # add hidden states from the last decoder layer
1291
+ if output_hidden_states:
1292
+ all_hidden_states += (hidden_states,)
1293
+
1295
+ return Ernie4_5_MoeModelOutputWithPast(
1296
+ last_hidden_state=hidden_states,
1297
+ past_key_values=past_key_values,
1298
+ hidden_states=all_hidden_states,
1299
+ attentions=all_self_attns,
1300
+ router_loss=all_router_loss,
1301
+ gate_logits=all_gate_logits,
1302
+ )
1303
+
1304
+ def _update_causal_mask(
1305
+ self,
1306
+ attention_mask: Union[torch.Tensor, "BlockMask"],
1307
+ input_tensor: torch.Tensor,
1308
+ cache_position: torch.Tensor,
1309
+ past_key_values: Cache,
1310
+ output_attentions: bool = False,
1311
+ ):
1312
+ if self.config._attn_implementation == "flash_attention_2":
1313
+ if attention_mask is not None and past_key_values is not None:
1314
+ is_padding_right = (
1315
+ attention_mask[:, -1].sum().item() != input_tensor.size()[0]
1316
+ )
1317
+ if is_padding_right:
1318
+ raise ValueError(
1319
+ "You are attempting to perform batched generation with padding_side='right'"
1320
+ " this may lead to unexpected behaviour for Flash Attention version of Qwen3. Make sure to "
1321
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1322
+ )
1323
+ if attention_mask is not None and 0.0 in attention_mask:
1324
+ return attention_mask
1325
+ return None
1326
+ if self.config._attn_implementation == "flex_attention":
1327
+ if isinstance(attention_mask, torch.Tensor):
1328
+ attention_mask = make_flex_block_causal_mask(attention_mask)
1329
+ return attention_mask
1330
+
1331
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
1332
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
1333
+ # to infer the attention mask.
1334
+ past_seen_tokens = (
1335
+ past_key_values.get_seq_length() if past_key_values is not None else 0
1336
+ )
1337
+ using_static_cache = isinstance(past_key_values, StaticCache)
1338
+ using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
1339
+
1340
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
1341
+ if (
1342
+ self.config._attn_implementation == "sdpa"
1343
+ and not (using_static_cache or using_sliding_window_cache)
1344
+ and not output_attentions
1345
+ ):
1346
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
1347
+ attention_mask,
1348
+ inputs_embeds=input_tensor,
1349
+ past_key_values_length=past_seen_tokens,
1350
+ sliding_window=self.config.sliding_window,
1351
+ is_training=self.training,
1352
+ ):
1353
+ return None
1354
+
1355
+ dtype = input_tensor.dtype
1356
+ min_dtype = torch.finfo(dtype).min
1357
+ sequence_length = input_tensor.shape[1]
1358
+ # SlidingWindowCache or StaticCache
1359
+ if using_sliding_window_cache or using_static_cache:
1360
+ target_length = past_key_values.get_max_cache_shape()
1361
+ # DynamicCache or no cache
1362
+ else:
1363
+ target_length = (
1364
+ attention_mask.shape[-1]
1365
+ if isinstance(attention_mask, torch.Tensor)
1366
+ else past_seen_tokens + sequence_length + 1
1367
+ )
1368
+
1369
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
1370
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
1371
+ attention_mask,
1372
+ sequence_length=sequence_length,
1373
+ target_length=target_length,
1374
+ dtype=dtype,
1375
+ cache_position=cache_position,
1376
+ batch_size=input_tensor.shape[0],
1377
+ config=self.config,
1378
+ past_key_values=past_key_values,
1379
+ )
1380
+
1381
+ if (
1382
+ self.config._attn_implementation == "sdpa"
1383
+ and attention_mask is not None
1384
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
1385
+ and not output_attentions
1386
+ ):
1387
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1388
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1389
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1390
+ causal_mask = AttentionMaskConverter._unmask_unattended(
1391
+ causal_mask, min_dtype
1392
+ )
1393
+
1394
+ return causal_mask
1395
+
1396
+ @staticmethod
1397
+ def _prepare_4d_causal_attention_mask_with_cache_position(
1398
+ attention_mask: torch.Tensor,
1399
+ sequence_length: int,
1400
+ target_length: int,
1401
+ dtype: torch.dtype,
1402
+ cache_position: torch.Tensor,
1403
+ batch_size: int,
1404
+ config: Ernie4_5_MoeConfig,
1405
+ past_key_values: Cache,
1406
+ ):
1407
+ """
1408
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
1409
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
1410
+
1411
+ Args:
1412
+ attention_mask (`torch.Tensor`):
1413
+ A 2D attention mask of shape `(batch_size, key_value_length)`,
1414
+ or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
1415
+ sequence_length (`int`):
1416
+ The sequence length being processed.
1417
+ target_length (`int`):
1418
+ The target length: when generating with static cache, the mask should be as long as the static cache,
1419
+ to account for the 0 padding, the part of the cache that is not filled yet.
1420
+ dtype (`torch.dtype`):
1421
+ The dtype to use for the 4D attention mask.
1422
+ cache_position (`torch.Tensor`):
1423
+ Indices depicting the position of the input sequence tokens in the sequence.
1424
+ batch_size (`torch.Tensor`):
1425
+ Batch size.
1426
+ config (`Ernie4_5_MoeConfig`):
1427
+ The model's configuration class
1428
+ past_key_values (`Cache`):
1429
+ The cache class that is being used currently to generate
1430
+ """
1431
+ if attention_mask is not None and attention_mask.dim() == 4:
1432
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
1433
+ causal_mask = attention_mask
1434
+ else:
1435
+ min_dtype = torch.finfo(dtype).min
1436
+ causal_mask = torch.full(
1437
+ (sequence_length, target_length),
1438
+ fill_value=min_dtype,
1439
+ dtype=dtype,
1440
+ device=cache_position.device,
1441
+ )
1442
+ diagonal_attend_mask = torch.arange(
1443
+ target_length, device=cache_position.device
1444
+ ) > cache_position.reshape(-1, 1)
1445
+ text_config = config.get_text_config()
1446
+ if (
1447
+ getattr(text_config, "use_sliding_window", True)
1448
+ and text_config.sliding_window is not None
1449
+ ):
1450
+ if (
1451
+ not isinstance(past_key_values, SlidingWindowCache)
1452
+ or sequence_length > target_length
1453
+ ):
1454
+ sliding_attend_mask = torch.arange(
1455
+ target_length, device=cache_position.device
1456
+ ) <= (cache_position.reshape(-1, 1) - text_config.sliding_window)
1457
+ diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
1458
+ causal_mask *= diagonal_attend_mask
1459
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
1460
+ if attention_mask is not None:
1461
+ causal_mask = (
1462
+ causal_mask.clone()
1463
+ ) # copy to contiguous memory for in-place edit
1464
+ if attention_mask.shape[-1] > target_length:
1465
+ attention_mask = attention_mask[:, :target_length]
1466
+ mask_length = attention_mask.shape[-1]
1467
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
1468
+ :, None, None, :
1469
+ ].to(causal_mask.device)
1470
+ padding_mask = padding_mask == 0
1471
+ causal_mask[:, :, :, :mask_length] = causal_mask[
1472
+ :, :, :, :mask_length
1473
+ ].masked_fill(padding_mask, min_dtype)
1474
+ return causal_mask
1475
+
1476
+
1477
+ @auto_docstring
1478
+ class Ernie4_5_MoeForCausalLM(Ernie4_5_PretrainedModel, GenerationMixin):
1479
+ """ERNIE Mixture of Experts (MoE) model for causal language modeling."""
1480
+
1481
+ _tied_weights_keys = ["lm_head.weight"]
1482
+ _tp_plan = {"lm_head": "colwise_rep"}
1483
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
1484
+
1485
+ def __init__(self, config):
1486
+ """
1487
+ Initializes the ERNIE MoE model for causal language modeling.
1488
+
1489
+ Args:
1490
+ config (dict): Model configuration.
1491
+ """
1492
+ super().__init__(config)
1493
+ self.config = config
1494
+ self.model = Ernie4_5_Model(config)
1495
+ self.lm_head = nn.Linear(
1496
+ config.hidden_size,
1497
+ config.vocab_size,
1498
+ bias=config.weight_share_add_bias and config.use_bias,
1499
+ ) # TODO
1500
+ self.loss_function = ErniePretrainingCriterion(config)
1501
+
1502
+ # Initialize weights and apply final processing
1503
+ self.post_init()
1504
+
1505
+ def get_input_embeddings(self):
1506
+ """Returns the input embeddings layer."""
1507
+ return self.model.embed_tokens
1508
+
1509
+ def set_input_embeddings(self, value):
1510
+ """Sets the input embeddings layer."""
1511
+ self.model.embed_tokens = value
1512
+
+     def get_output_embeddings(self):
+         """Returns the output embeddings (LM head)."""
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         """Sets the output embeddings layer."""
+         self.lm_head = new_embeddings
+
+     def set_decoder(self, decoder):
+         """Sets the ERNIE decoder model."""
+         self.model = decoder
+
+     def get_decoder(self):
+         """Get the transformer decoder."""
+         return self.model
+
+     @can_return_tuple
+     def forward(
+         self,
+         input_ids,
+         attention_mask=None,
+         position_ids=None,
+         past_key_values: Optional[list[torch.FloatTensor]] = None,
+         inputs_embeds=None,
+         labels=None,
+         loss_mask=None,
+         use_cache=False,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         **kwargs: Unpack[KwargsForCausalLM],
+     ):
+         """
+         Forward pass for causal language modeling.
+         """
+         output_attentions = (
+             output_attentions
+             if output_attentions is not None
+             else self.config.output_attentions
+         )
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+
+         outputs = self.model(
+             input_ids,
+             position_ids=position_ids,
+             attention_mask=attention_mask,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             past_key_values=past_key_values,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             **kwargs,
+         )
+
+         hidden_states = outputs.last_hidden_state
+         logits = self.lm_head(hidden_states)
+
+         loss, router_loss = None, None
+         if getattr(self.config, "use_moe", False):
+             router_loss = outputs.router_loss
+
+         if labels is not None:
+             loss, _ = self.loss_function(logits, labels, loss_mask, router_loss)
+
+         return Ernie4_5_MoeCausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+             router_loss=router_loss,
+         )
+
+
+ __all__ = ["Ernie4_5_Model", "Ernie4_5_MoeForCausalLM", "Ernie4_5_PretrainedModel"]
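
With the modeling code complete, a quick smoke test can exercise the whole checkpoint. The snippet below is a sketch, not part of the upload: the repository path is a placeholder, `trust_remote_code=True` is needed because the model and tokenizer classes ship inside the repo, and it assumes the accompanying `config.json` wires up `auto_map` the same way `tokenizer_config.json` below does.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo = "path/to/this-checkpoint"  # placeholder: local folder or Hub id of this upload
    tok = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True, torch_dtype=torch.bfloat16)

    inputs = tok("User: Hello!\nAssistant: ", return_tensors="pt")
    out = model.generate(inputs.input_ids, max_new_tokens=32)
    print(tok.decode(out[0], skip_special_tokens=True))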
special_tokens_map.json ADDED
@@ -0,0 +1 @@
+ {"bos_token": "<s>", "eos_token": "</s>", "pad_token": "<unk>", "unk_token": "<unk>", "cls_token": "<|begin_of_sentence|>", "sep_token": "<|end_of_sentence|>", "mask_token": "<mask:1>", "sys_start_token": "<mask:4>", "sys_end_token": "<mask:5>", "header_start_token": "<mask:6>", "header_end_token": "<mask:7>", "additional_special_tokens": ["<|IMAGE_PLACEHOLDER|>", "<|AUDIO_PLACEHOLDER|>", "<|LOC_0|>", "<|LOC_1|>", "<|LOC_2|>", "<|LOC_3|>", "<|LOC_4|>", "<|LOC_5|>", "<|LOC_6|>", "<|LOC_7|>", "<|LOC_8|>", "<|LOC_9|>", "<|LOC_10|>", "<|LOC_11|>", "<|LOC_12|>", "<|LOC_13|>", "<|LOC_14|>", "<|LOC_15|>", "<|LOC_16|>", "<|LOC_17|>", "<|LOC_18|>", "<|LOC_19|>", "<|LOC_20|>", "<|LOC_21|>", "<|LOC_22|>", "<|LOC_23|>", "<|LOC_24|>", "<|LOC_25|>", "<|LOC_26|>", "<|LOC_27|>", "<|LOC_28|>", "<|LOC_29|>", "<|LOC_30|>", "<|LOC_31|>", "<|LOC_32|>", "<|LOC_33|>", "<|LOC_34|>", "<|LOC_35|>", "<|LOC_36|>", "<|LOC_37|>", "<|LOC_38|>", "<|LOC_39|>", "<|LOC_40|>", "<|LOC_41|>", "<|LOC_42|>", "<|LOC_43|>", "<|LOC_44|>", "<|LOC_45|>", "<|LOC_46|>", "<|LOC_47|>", "<|LOC_48|>", "<|LOC_49|>", "<|LOC_50|>", "<|LOC_51|>", "<|LOC_52|>", "<|LOC_53|>", "<|LOC_54|>", "<|LOC_55|>", "<|LOC_56|>", "<|LOC_57|>", "<|LOC_58|>", "<|LOC_59|>", "<|LOC_60|>", "<|LOC_61|>", "<|LOC_62|>", "<|LOC_63|>", "<|LOC_64|>", "<|LOC_65|>", "<|LOC_66|>", "<|LOC_67|>", "<|LOC_68|>", "<|LOC_69|>", "<|LOC_70|>", "<|LOC_71|>", "<|LOC_72|>", "<|LOC_73|>", "<|LOC_74|>", "<|LOC_75|>", "<|LOC_76|>", "<|LOC_77|>", "<|LOC_78|>", "<|LOC_79|>", "<|LOC_80|>", "<|LOC_81|>", "<|LOC_82|>", "<|LOC_83|>", "<|LOC_84|>", "<|LOC_85|>", "<|LOC_86|>", "<|LOC_87|>", "<|LOC_88|>", "<|LOC_89|>", "<|LOC_90|>", "<|LOC_91|>", "<|LOC_92|>", "<|LOC_93|>", "<|LOC_94|>", "<|LOC_95|>", "<|LOC_96|>", "<|LOC_97|>", "<|LOC_98|>", "<|LOC_99|>", "<|LOC_100|>", "<|LOC_101|>", "<|LOC_102|>", "<|LOC_103|>", "<|LOC_104|>", "<|LOC_105|>", "<|LOC_106|>", "<|LOC_107|>", "<|LOC_108|>", "<|LOC_109|>", "<|LOC_110|>", "<|LOC_111|>", "<|LOC_112|>", "<|LOC_113|>", "<|LOC_114|>", "<|LOC_115|>", "<|LOC_116|>", "<|LOC_117|>", "<|LOC_118|>", "<|LOC_119|>", "<|LOC_120|>", "<|LOC_121|>", "<|LOC_122|>", "<|LOC_123|>", "<|LOC_124|>", "<|LOC_125|>", "<|LOC_126|>", "<|LOC_127|>", "<|LOC_128|>", "<|LOC_129|>", "<|LOC_130|>", "<|LOC_131|>", "<|LOC_132|>", "<|LOC_133|>", "<|LOC_134|>", "<|LOC_135|>", "<|LOC_136|>", "<|LOC_137|>", "<|LOC_138|>", "<|LOC_139|>", "<|LOC_140|>", "<|LOC_141|>", "<|LOC_142|>", "<|LOC_143|>", "<|LOC_144|>", "<|LOC_145|>", "<|LOC_146|>", "<|LOC_147|>", "<|LOC_148|>", "<|LOC_149|>", "<|LOC_150|>", "<|LOC_151|>", "<|LOC_152|>", "<|LOC_153|>", "<|LOC_154|>", "<|LOC_155|>", "<|LOC_156|>", "<|LOC_157|>", "<|LOC_158|>", "<|LOC_159|>", "<|LOC_160|>", "<|LOC_161|>", "<|LOC_162|>", "<|LOC_163|>", "<|LOC_164|>", "<|LOC_165|>", "<|LOC_166|>", "<|LOC_167|>", "<|LOC_168|>", "<|LOC_169|>", "<|LOC_170|>", "<|LOC_171|>", "<|LOC_172|>", "<|LOC_173|>", "<|LOC_174|>", "<|LOC_175|>", "<|LOC_176|>", "<|LOC_177|>", "<|LOC_178|>", "<|LOC_179|>", "<|LOC_180|>", "<|LOC_181|>", "<|LOC_182|>", "<|LOC_183|>", "<|LOC_184|>", "<|LOC_185|>", "<|LOC_186|>", "<|LOC_187|>", "<|LOC_188|>", "<|LOC_189|>", "<|LOC_190|>", "<|LOC_191|>", "<|LOC_192|>", "<|LOC_193|>", "<|LOC_194|>", "<|LOC_195|>", "<|LOC_196|>", "<|LOC_197|>", "<|LOC_198|>", "<|LOC_199|>", "<|LOC_200|>", "<|LOC_201|>", "<|LOC_202|>", "<|LOC_203|>", "<|LOC_204|>", "<|LOC_205|>", "<|LOC_206|>", "<|LOC_207|>", "<|LOC_208|>", "<|LOC_209|>", "<|LOC_210|>", "<|LOC_211|>", "<|LOC_212|>", "<|LOC_213|>", "<|LOC_214|>", "<|LOC_215|>", "<|LOC_216|>", "<|LOC_217|>", 
"<|LOC_218|>", "<|LOC_219|>", "<|LOC_220|>", "<|LOC_221|>", "<|LOC_222|>", "<|LOC_223|>", "<|LOC_224|>", "<|LOC_225|>", "<|LOC_226|>", "<|LOC_227|>", "<|LOC_228|>", "<|LOC_229|>", "<|LOC_230|>", "<|LOC_231|>", "<|LOC_232|>", "<|LOC_233|>", "<|LOC_234|>", "<|LOC_235|>", "<|LOC_236|>", "<|LOC_237|>", "<|LOC_238|>", "<|LOC_239|>", "<|LOC_240|>", "<|LOC_241|>", "<|LOC_242|>", "<|LOC_243|>", "<|LOC_244|>", "<|LOC_245|>", "<|LOC_246|>", "<|LOC_247|>", "<|LOC_248|>", "<|LOC_249|>", "<|LOC_250|>", "<|LOC_251|>", "<|LOC_252|>", "<|LOC_253|>", "<|LOC_254|>", "<|LOC_255|>", "<|LOC_256|>", "<|LOC_257|>", "<|LOC_258|>", "<|LOC_259|>", "<|LOC_260|>", "<|LOC_261|>", "<|LOC_262|>", "<|LOC_263|>", "<|LOC_264|>", "<|LOC_265|>", "<|LOC_266|>", "<|LOC_267|>", "<|LOC_268|>", "<|LOC_269|>", "<|LOC_270|>", "<|LOC_271|>", "<|LOC_272|>", "<|LOC_273|>", "<|LOC_274|>", "<|LOC_275|>", "<|LOC_276|>", "<|LOC_277|>", "<|LOC_278|>", "<|LOC_279|>", "<|LOC_280|>", "<|LOC_281|>", "<|LOC_282|>", "<|LOC_283|>", "<|LOC_284|>", "<|LOC_285|>", "<|LOC_286|>", "<|LOC_287|>", "<|LOC_288|>", "<|LOC_289|>", "<|LOC_290|>", "<|LOC_291|>", "<|LOC_292|>", "<|LOC_293|>", "<|LOC_294|>", "<|LOC_295|>", "<|LOC_296|>", "<|LOC_297|>", "<|LOC_298|>", "<|LOC_299|>", "<|LOC_300|>", "<|LOC_301|>", "<|LOC_302|>", "<|LOC_303|>", "<|LOC_304|>", "<|LOC_305|>", "<|LOC_306|>", "<|LOC_307|>", "<|LOC_308|>", "<|LOC_309|>", "<|LOC_310|>", "<|LOC_311|>", "<|LOC_312|>", "<|LOC_313|>", "<|LOC_314|>", "<|LOC_315|>", "<|LOC_316|>", "<|LOC_317|>", "<|LOC_318|>", "<|LOC_319|>", "<|LOC_320|>", "<|LOC_321|>", "<|LOC_322|>", "<|LOC_323|>", "<|LOC_324|>", "<|LOC_325|>", "<|LOC_326|>", "<|LOC_327|>", "<|LOC_328|>", "<|LOC_329|>", "<|LOC_330|>", "<|LOC_331|>", "<|LOC_332|>", "<|LOC_333|>", "<|LOC_334|>", "<|LOC_335|>", "<|LOC_336|>", "<|LOC_337|>", "<|LOC_338|>", "<|LOC_339|>", "<|LOC_340|>", "<|LOC_341|>", "<|LOC_342|>", "<|LOC_343|>", "<|LOC_344|>", "<|LOC_345|>", "<|LOC_346|>", "<|LOC_347|>", "<|LOC_348|>", "<|LOC_349|>", "<|LOC_350|>", "<|LOC_351|>", "<|LOC_352|>", "<|LOC_353|>", "<|LOC_354|>", "<|LOC_355|>", "<|LOC_356|>", "<|LOC_357|>", "<|LOC_358|>", "<|LOC_359|>", "<|LOC_360|>", "<|LOC_361|>", "<|LOC_362|>", "<|LOC_363|>", "<|LOC_364|>", "<|LOC_365|>", "<|LOC_366|>", "<|LOC_367|>", "<|LOC_368|>", "<|LOC_369|>", "<|LOC_370|>", "<|LOC_371|>", "<|LOC_372|>", "<|LOC_373|>", "<|LOC_374|>", "<|LOC_375|>", "<|LOC_376|>", "<|LOC_377|>", "<|LOC_378|>", "<|LOC_379|>", "<|LOC_380|>", "<|LOC_381|>", "<|LOC_382|>", "<|LOC_383|>", "<|LOC_384|>", "<|LOC_385|>", "<|LOC_386|>", "<|LOC_387|>", "<|LOC_388|>", "<|LOC_389|>", "<|LOC_390|>", "<|LOC_391|>", "<|LOC_392|>", "<|LOC_393|>", "<|LOC_394|>", "<|LOC_395|>", "<|LOC_396|>", "<|LOC_397|>", "<|LOC_398|>", "<|LOC_399|>", "<|LOC_400|>", "<|LOC_401|>", "<|LOC_402|>", "<|LOC_403|>", "<|LOC_404|>", "<|LOC_405|>", "<|LOC_406|>", "<|LOC_407|>", "<|LOC_408|>", "<|LOC_409|>", "<|LOC_410|>", "<|LOC_411|>", "<|LOC_412|>", "<|LOC_413|>", "<|LOC_414|>", "<|LOC_415|>", "<|LOC_416|>", "<|LOC_417|>", "<|LOC_418|>", "<|LOC_419|>", "<|LOC_420|>", "<|LOC_421|>", "<|LOC_422|>", "<|LOC_423|>", "<|LOC_424|>", "<|LOC_425|>", "<|LOC_426|>", "<|LOC_427|>", "<|LOC_428|>", "<|LOC_429|>", "<|LOC_430|>", "<|LOC_431|>", "<|LOC_432|>", "<|LOC_433|>", "<|LOC_434|>", "<|LOC_435|>", "<|LOC_436|>", "<|LOC_437|>", "<|LOC_438|>", "<|LOC_439|>", "<|LOC_440|>", "<|LOC_441|>", "<|LOC_442|>", "<|LOC_443|>", "<|LOC_444|>", "<|LOC_445|>", "<|LOC_446|>", "<|LOC_447|>", "<|LOC_448|>", "<|LOC_449|>", "<|LOC_450|>", "<|LOC_451|>", "<|LOC_452|>", "<|LOC_453|>", "<|LOC_454|>", 
"<|LOC_455|>", "<|LOC_456|>", "<|LOC_457|>", "<|LOC_458|>", "<|LOC_459|>", "<|LOC_460|>", "<|LOC_461|>", "<|LOC_462|>", "<|LOC_463|>", "<|LOC_464|>", "<|LOC_465|>", "<|LOC_466|>", "<|LOC_467|>", "<|LOC_468|>", "<|LOC_469|>", "<|LOC_470|>", "<|LOC_471|>", "<|LOC_472|>", "<|LOC_473|>", "<|LOC_474|>", "<|LOC_475|>", "<|LOC_476|>", "<|LOC_477|>", "<|LOC_478|>", "<|LOC_479|>", "<|LOC_480|>", "<|LOC_481|>", "<|LOC_482|>", "<|LOC_483|>", "<|LOC_484|>", "<|LOC_485|>", "<|LOC_486|>", "<|LOC_487|>", "<|LOC_488|>", "<|LOC_489|>", "<|LOC_490|>", "<|LOC_491|>", "<|LOC_492|>", "<|LOC_493|>", "<|LOC_494|>", "<|LOC_495|>", "<|LOC_496|>", "<|LOC_497|>", "<|LOC_498|>", "<|LOC_499|>", "<|LOC_500|>", "<|LOC_501|>", "<|LOC_502|>", "<|LOC_503|>", "<|LOC_504|>", "<|LOC_505|>", "<|LOC_506|>", "<|LOC_507|>", "<|LOC_508|>", "<|LOC_509|>", "<|LOC_510|>", "<|LOC_511|>", "<|LOC_512|>", "<|LOC_513|>", "<|LOC_514|>", "<|LOC_515|>", "<|LOC_516|>", "<|LOC_517|>", "<|LOC_518|>", "<|LOC_519|>", "<|LOC_520|>", "<|LOC_521|>", "<|LOC_522|>", "<|LOC_523|>", "<|LOC_524|>", "<|LOC_525|>", "<|LOC_526|>", "<|LOC_527|>", "<|LOC_528|>", "<|LOC_529|>", "<|LOC_530|>", "<|LOC_531|>", "<|LOC_532|>", "<|LOC_533|>", "<|LOC_534|>", "<|LOC_535|>", "<|LOC_536|>", "<|LOC_537|>", "<|LOC_538|>", "<|LOC_539|>", "<|LOC_540|>", "<|LOC_541|>", "<|LOC_542|>", "<|LOC_543|>", "<|LOC_544|>", "<|LOC_545|>", "<|LOC_546|>", "<|LOC_547|>", "<|LOC_548|>", "<|LOC_549|>", "<|LOC_550|>", "<|LOC_551|>", "<|LOC_552|>", "<|LOC_553|>", "<|LOC_554|>", "<|LOC_555|>", "<|LOC_556|>", "<|LOC_557|>", "<|LOC_558|>", "<|LOC_559|>", "<|LOC_560|>", "<|LOC_561|>", "<|LOC_562|>", "<|LOC_563|>", "<|LOC_564|>", "<|LOC_565|>", "<|LOC_566|>", "<|LOC_567|>", "<|LOC_568|>", "<|LOC_569|>", "<|LOC_570|>", "<|LOC_571|>", "<|LOC_572|>", "<|LOC_573|>", "<|LOC_574|>", "<|LOC_575|>", "<|LOC_576|>", "<|LOC_577|>", "<|LOC_578|>", "<|LOC_579|>", "<|LOC_580|>", "<|LOC_581|>", "<|LOC_582|>", "<|LOC_583|>", "<|LOC_584|>", "<|LOC_585|>", "<|LOC_586|>", "<|LOC_587|>", "<|LOC_588|>", "<|LOC_589|>", "<|LOC_590|>", "<|LOC_591|>", "<|LOC_592|>", "<|LOC_593|>", "<|LOC_594|>", "<|LOC_595|>", "<|LOC_596|>", "<|LOC_597|>", "<|LOC_598|>", "<|LOC_599|>", "<|LOC_600|>", "<|LOC_601|>", "<|LOC_602|>", "<|LOC_603|>", "<|LOC_604|>", "<|LOC_605|>", "<|LOC_606|>", "<|LOC_607|>", "<|LOC_608|>", "<|LOC_609|>", "<|LOC_610|>", "<|LOC_611|>", "<|LOC_612|>", "<|LOC_613|>", "<|LOC_614|>", "<|LOC_615|>", "<|LOC_616|>", "<|LOC_617|>", "<|LOC_618|>", "<|LOC_619|>", "<|LOC_620|>", "<|LOC_621|>", "<|LOC_622|>", "<|LOC_623|>", "<|LOC_624|>", "<|LOC_625|>", "<|LOC_626|>", "<|LOC_627|>", "<|LOC_628|>", "<|LOC_629|>", "<|LOC_630|>", "<|LOC_631|>", "<|LOC_632|>", "<|LOC_633|>", "<|LOC_634|>", "<|LOC_635|>", "<|LOC_636|>", "<|LOC_637|>", "<|LOC_638|>", "<|LOC_639|>", "<|LOC_640|>", "<|LOC_641|>", "<|LOC_642|>", "<|LOC_643|>", "<|LOC_644|>", "<|LOC_645|>", "<|LOC_646|>", "<|LOC_647|>", "<|LOC_648|>", "<|LOC_649|>", "<|LOC_650|>", "<|LOC_651|>", "<|LOC_652|>", "<|LOC_653|>", "<|LOC_654|>", "<|LOC_655|>", "<|LOC_656|>", "<|LOC_657|>", "<|LOC_658|>", "<|LOC_659|>", "<|LOC_660|>", "<|LOC_661|>", "<|LOC_662|>", "<|LOC_663|>", "<|LOC_664|>", "<|LOC_665|>", "<|LOC_666|>", "<|LOC_667|>", "<|LOC_668|>", "<|LOC_669|>", "<|LOC_670|>", "<|LOC_671|>", "<|LOC_672|>", "<|LOC_673|>", "<|LOC_674|>", "<|LOC_675|>", "<|LOC_676|>", "<|LOC_677|>", "<|LOC_678|>", "<|LOC_679|>", "<|LOC_680|>", "<|LOC_681|>", "<|LOC_682|>", "<|LOC_683|>", "<|LOC_684|>", "<|LOC_685|>", "<|LOC_686|>", "<|LOC_687|>", "<|LOC_688|>", "<|LOC_689|>", "<|LOC_690|>", "<|LOC_691|>", 
"<|LOC_692|>", "<|LOC_693|>", "<|LOC_694|>", "<|LOC_695|>", "<|LOC_696|>", "<|LOC_697|>", "<|LOC_698|>", "<|LOC_699|>", "<|LOC_700|>", "<|LOC_701|>", "<|LOC_702|>", "<|LOC_703|>", "<|LOC_704|>", "<|LOC_705|>", "<|LOC_706|>", "<|LOC_707|>", "<|LOC_708|>", "<|LOC_709|>", "<|LOC_710|>", "<|LOC_711|>", "<|LOC_712|>", "<|LOC_713|>", "<|LOC_714|>", "<|LOC_715|>", "<|LOC_716|>", "<|LOC_717|>", "<|LOC_718|>", "<|LOC_719|>", "<|LOC_720|>", "<|LOC_721|>", "<|LOC_722|>", "<|LOC_723|>", "<|LOC_724|>", "<|LOC_725|>", "<|LOC_726|>", "<|LOC_727|>", "<|LOC_728|>", "<|LOC_729|>", "<|LOC_730|>", "<|LOC_731|>", "<|LOC_732|>", "<|LOC_733|>", "<|LOC_734|>", "<|LOC_735|>", "<|LOC_736|>", "<|LOC_737|>", "<|LOC_738|>", "<|LOC_739|>", "<|LOC_740|>", "<|LOC_741|>", "<|LOC_742|>", "<|LOC_743|>", "<|LOC_744|>", "<|LOC_745|>", "<|LOC_746|>", "<|LOC_747|>", "<|LOC_748|>", "<|LOC_749|>", "<|LOC_750|>", "<|LOC_751|>", "<|LOC_752|>", "<|LOC_753|>", "<|LOC_754|>", "<|LOC_755|>", "<|LOC_756|>", "<|LOC_757|>", "<|LOC_758|>", "<|LOC_759|>", "<|LOC_760|>", "<|LOC_761|>", "<|LOC_762|>", "<|LOC_763|>", "<|LOC_764|>", "<|LOC_765|>", "<|LOC_766|>", "<|LOC_767|>", "<|LOC_768|>", "<|LOC_769|>", "<|LOC_770|>", "<|LOC_771|>", "<|LOC_772|>", "<|LOC_773|>", "<|LOC_774|>", "<|LOC_775|>", "<|LOC_776|>", "<|LOC_777|>", "<|LOC_778|>", "<|LOC_779|>", "<|LOC_780|>", "<|LOC_781|>", "<|LOC_782|>", "<|LOC_783|>", "<|LOC_784|>", "<|LOC_785|>", "<|LOC_786|>", "<|LOC_787|>", "<|LOC_788|>", "<|LOC_789|>", "<|LOC_790|>", "<|LOC_791|>", "<|LOC_792|>", "<|LOC_793|>", "<|LOC_794|>", "<|LOC_795|>", "<|LOC_796|>", "<|LOC_797|>", "<|LOC_798|>", "<|LOC_799|>", "<|LOC_800|>", "<|LOC_801|>", "<|LOC_802|>", "<|LOC_803|>", "<|LOC_804|>", "<|LOC_805|>", "<|LOC_806|>", "<|LOC_807|>", "<|LOC_808|>", "<|LOC_809|>", "<|LOC_810|>", "<|LOC_811|>", "<|LOC_812|>", "<|LOC_813|>", "<|LOC_814|>", "<|LOC_815|>", "<|LOC_816|>", "<|LOC_817|>", "<|LOC_818|>", "<|LOC_819|>", "<|LOC_820|>", "<|LOC_821|>", "<|LOC_822|>", "<|LOC_823|>", "<|LOC_824|>", "<|LOC_825|>", "<|LOC_826|>", "<|LOC_827|>", "<|LOC_828|>", "<|LOC_829|>", "<|LOC_830|>", "<|LOC_831|>", "<|LOC_832|>", "<|LOC_833|>", "<|LOC_834|>", "<|LOC_835|>", "<|LOC_836|>", "<|LOC_837|>", "<|LOC_838|>", "<|LOC_839|>", "<|LOC_840|>", "<|LOC_841|>", "<|LOC_842|>", "<|LOC_843|>", "<|LOC_844|>", "<|LOC_845|>", "<|LOC_846|>", "<|LOC_847|>", "<|LOC_848|>", "<|LOC_849|>", "<|LOC_850|>", "<|LOC_851|>", "<|LOC_852|>", "<|LOC_853|>", "<|LOC_854|>", "<|LOC_855|>", "<|LOC_856|>", "<|LOC_857|>", "<|LOC_858|>", "<|LOC_859|>", "<|LOC_860|>", "<|LOC_861|>", "<|LOC_862|>", "<|LOC_863|>", "<|LOC_864|>", "<|LOC_865|>", "<|LOC_866|>", "<|LOC_867|>", "<|LOC_868|>", "<|LOC_869|>", "<|LOC_870|>", "<|LOC_871|>", "<|LOC_872|>", "<|LOC_873|>", "<|LOC_874|>", "<|LOC_875|>", "<|LOC_876|>", "<|LOC_877|>", "<|LOC_878|>", "<|LOC_879|>", "<|LOC_880|>", "<|LOC_881|>", "<|LOC_882|>", "<|LOC_883|>", "<|LOC_884|>", "<|LOC_885|>", "<|LOC_886|>", "<|LOC_887|>", "<|LOC_888|>", "<|LOC_889|>", "<|LOC_890|>", "<|LOC_891|>", "<|LOC_892|>", "<|LOC_893|>", "<|LOC_894|>", "<|LOC_895|>", "<|LOC_896|>", "<|LOC_897|>", "<|LOC_898|>", "<|LOC_899|>", "<|LOC_900|>", "<|LOC_901|>", "<|LOC_902|>", "<|LOC_903|>", "<|LOC_904|>", "<|LOC_905|>", "<|LOC_906|>", "<|LOC_907|>", "<|LOC_908|>", "<|LOC_909|>", "<|LOC_910|>", "<|LOC_911|>", "<|LOC_912|>", "<|LOC_913|>", "<|LOC_914|>", "<|LOC_915|>", "<|LOC_916|>", "<|LOC_917|>", "<|LOC_918|>", "<|LOC_919|>", "<|LOC_920|>", "<|LOC_921|>", "<|LOC_922|>", "<|LOC_923|>", "<|LOC_924|>", "<|LOC_925|>", "<|LOC_926|>", "<|LOC_927|>", "<|LOC_928|>", 
"<|LOC_929|>", "<|LOC_930|>", "<|LOC_931|>", "<|LOC_932|>", "<|LOC_933|>", "<|LOC_934|>", "<|LOC_935|>", "<|LOC_936|>", "<|LOC_937|>", "<|LOC_938|>", "<|LOC_939|>", "<|LOC_940|>", "<|LOC_941|>", "<|LOC_942|>", "<|LOC_943|>", "<|LOC_944|>", "<|LOC_945|>", "<|LOC_946|>", "<|LOC_947|>", "<|LOC_948|>", "<|LOC_949|>", "<|LOC_950|>", "<|LOC_951|>", "<|LOC_952|>", "<|LOC_953|>", "<|LOC_954|>", "<|LOC_955|>", "<|LOC_956|>", "<|LOC_957|>", "<|LOC_958|>", "<|LOC_959|>", "<|LOC_960|>", "<|LOC_961|>", "<|LOC_962|>", "<|LOC_963|>", "<|LOC_964|>", "<|LOC_965|>", "<|LOC_966|>", "<|LOC_967|>", "<|LOC_968|>", "<|LOC_969|>", "<|LOC_970|>", "<|LOC_971|>", "<|LOC_972|>", "<|LOC_973|>", "<|LOC_974|>", "<|LOC_975|>", "<|LOC_976|>", "<|LOC_977|>", "<|LOC_978|>", "<|LOC_979|>", "<|LOC_980|>", "<|LOC_981|>", "<|LOC_982|>", "<|LOC_983|>", "<|LOC_984|>", "<|LOC_985|>", "<|LOC_986|>", "<|LOC_987|>", "<|LOC_988|>", "<|LOC_989|>", "<|LOC_990|>", "<|LOC_991|>", "<|LOC_992|>", "<|LOC_993|>", "<|LOC_994|>", "<|LOC_995|>", "<|LOC_996|>", "<|LOC_997|>", "<|LOC_998|>", "<|LOC_999|>", "<|LOC_1000|>", "<|LOC_BEGIN|>", "<|LOC_END|>", "<|LOC_SEP|>", "<|CROP_COL_SEP|>", "<|CROP_ROW_SEP|>", "<|IMAGE_SEP|>"]}
tokenization_ernie4_5.py ADDED
@@ -0,0 +1,353 @@
+ # Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Ernie4_5_Tokenizer"""
+
+ import os
+ from shutil import copyfile
+ from typing import Dict, List, Optional, Tuple, Union
+
+ import torch
+ import numpy as np
+ import sentencepiece as spm
+
+ from transformers.tokenization_utils import PreTrainedTokenizer
+ from transformers.tokenization_utils_base import (
+     PaddingStrategy,
+ )
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class Ernie4_5_Tokenizer(PreTrainedTokenizer):
+ """
36
+ Ernie4_5_Tokenizer
37
+ vocab_files_names (dict): Mapping vocabulary-related config name to actual filename.
38
+ model_input_names (List): Model input names expected by the tokenizer
39
+ padding_side (str): Padding side (where to add padding tokens)
40
+ """
41
+     vocab_files_names = {
+         "vocab_file": "tokenizer.model",
+     }
+     model_input_names = ["input_ids", "position_ids", "attention_mask", "labels"]
+     padding_side = "right"
+
+     def __init__(
+         self,
+         vocab_file,
+         bos_token="<s>",
+         cls_token="<cls>",
+         eos_token="</s>",
+         mask_token="<mask:0>",
+         pad_token="<pad>",
+         sep_token="<sep>",
+         unk_token="<unk>",
+         additional_special_tokens=None,
+         split_special_tokens=False,
+         tokenizer_alpha=None,
+         **kwargs,
+     ):
+         """
+         Initialize the ERNIE tokenizer.
+
+         Args:
+             vocab_file (str): Path to the SentencePiece model file.
+             bos_token (str, optional): Beginning of sentence token. Defaults to "<s>".
+             cls_token (str, optional): Classification token. Defaults to "<cls>".
+             eos_token (str, optional): End of sentence token. Defaults to "</s>".
+             mask_token (str, optional): Mask token. Defaults to "<mask:0>".
+             pad_token (str, optional): Padding token. Defaults to "<pad>".
+             sep_token (str, optional): Separator token. Defaults to "<sep>".
+             unk_token (str, optional): Unknown token. Defaults to "<unk>".
+             additional_special_tokens (List[str], optional): Additional special tokens.
+                 Defaults to ["<mask:1>", "<mask:7>"].
+             split_special_tokens (bool, optional): Whether to split special tokens. Defaults to False.
+             tokenizer_alpha (float, optional): Alpha parameter for SentencePiece sampling.
+             **kwargs: Additional keyword arguments passed to the parent class.
+         """
+
+         self.vocab_file = vocab_file
+         self.sp_model = spm.SentencePieceProcessor()
+         self.sp_model.Load(vocab_file)
+         self.pad_id = self._convert_token_to_id(pad_token)
+         self.tokenizer_alpha = tokenizer_alpha
+
+         if additional_special_tokens is None:
+             additional_special_tokens = ["<mask:1>", "<mask:7>"]
+         super().__init__(
+             bos_token=bos_token,
+             cls_token=cls_token,
+             eos_token=eos_token,
+             mask_token=mask_token,
+             pad_token=pad_token,
+             sep_token=sep_token,
+             unk_token=unk_token,
+             additional_special_tokens=additional_special_tokens,
+             split_special_tokens=split_special_tokens,
+             **kwargs,
+         )
+
+     @property
+     def vocab_size(self):
+         """Returns the size of the vocabulary.
+
+         Returns:
+             int: The number of tokens in the vocabulary.
+         """
+         return self.sp_model.vocab_size()
+
+     def get_vocab(self):
+         """Get the vocabulary as a dictionary mapping tokens to their IDs.
+
+         Returns:
+             dict: A dictionary mapping tokens to their corresponding IDs.
+         """
+         vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+         vocab.update(self.added_tokens_encoder)
+         return vocab
+
+     def _tokenize(self, text):
+         """Tokenize text using SentencePiece.
+
+         Args:
+             text (str): The text to tokenize.
+
+         Returns:
+             list: A list of tokens.
+         """
+         if self.tokenizer_alpha is not None:
+             return self.sp_model.encode_as_pieces(
+                 text,
+                 enable_sampling=True,
+                 nbest_size=-1,
+                 alpha=self.tokenizer_alpha,
+             )
+         else:
+             return self.sp_model.encode_as_pieces(text)
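+
+ # Editor's note (illustration, not part of the original file): when tokenizer_alpha
+ # is set, encode_as_pieces performs SentencePiece subword regularization, sampling
+ # one segmentation from the full lattice (nbest_size=-1) with smoothing parameter
+ # alpha instead of returning the single best split. This is a training-time
+ # augmentation; leave tokenizer_alpha as None for deterministic inference.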
+
+     def _convert_token_to_id(self, token):
+         """Convert a token (str) to an ID using the vocabulary.
+
+         Args:
+             token (str): The token to convert.
+
+         Returns:
+             int: The corresponding token ID.
+         """
+         return self.sp_model.piece_to_id(token)
+
+     def _convert_id_to_token(self, index):
+         """Convert an ID to a token (str) using the vocabulary.
+
+         Args:
+             index (int): The token ID to convert.
+
+         Returns:
+             str: The corresponding token.
+         """
+         if index >= self.vocab_size:
+             return self.unk_token
+         else:
+             return self.sp_model.id_to_piece(index)
+
+     def convert_tokens_to_string(self, tokens):
+         """Convert a sequence of tokens back to a single string.
+
+         Args:
+             tokens (List[str]): A list of tokens to convert.
+
+         Returns:
+             str: The reconstructed string.
+         """
+         current_sub_tokens = []
+         out_string = ""
+         prev_is_special = False
+         for token in tokens:
+             # make sure that special tokens are not decoded using sentencepiece model
+             if token in self.all_special_tokens:
+                 if not prev_is_special:
+                     out_string += " "
+                 out_string += self.sp_model.decode(current_sub_tokens) + token
+                 prev_is_special = True
+                 current_sub_tokens = []
+             else:
+                 current_sub_tokens.append(token)
+                 prev_is_special = False
+         out_string += self.sp_model.decode(current_sub_tokens)
+         return out_string
+
+     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+         """Build model inputs by adding special tokens to sequences.
+
+         Args:
+             token_ids_0 (List[int]): List of token IDs for the first sequence.
+             token_ids_1 (List[int], optional): List of token IDs for the second sequence.
+
+         Returns:
+             List[int]: List of token IDs with special tokens added.
+         """
+         output = token_ids_0
+         last_cls_index = -1
+         last_sep_index = -1
+         if self.cls_token_id in output:
+             last_cls_index = len(output) - output[::-1].index(self.cls_token_id) - 1
+         if self.sep_token_id in output:
+             last_sep_index = len(output) - output[::-1].index(self.sep_token_id) - 1
+
+         if last_cls_index > last_sep_index:
+             next_token_id = self.sep_token_id
+         elif last_sep_index > last_cls_index:
+             next_token_id = self.cls_token_id
+         else:
+             output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+             next_token_id = self.cls_token_id
+
+         output = [self.bos_token_id] + output
+         # Assume no markup in text if token_ids_1 is given.
+         if token_ids_1 is not None:
+             output = output + token_ids_1 + [next_token_id]
+         return output
+
+     def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
+         """Get a mask showing which tokens are special tokens.
+
+         Args:
+             token_ids_0 (List[int]): List of token IDs for the first sequence.
+             token_ids_1 (List[int], optional): List of token IDs for the second sequence.
+             already_has_special_tokens (bool): Whether the tokens already include special tokens.
+
+         Returns:
+             List[int]: A mask where 1 indicates special tokens and 0 indicates regular tokens.
+         """
+         if already_has_special_tokens:
+             return super().get_special_tokens_mask(token_ids_0, token_ids_1, already_has_special_tokens=True)
+
+         # [bos_token, cls_token, tokens_0, sep_token]
+         if token_ids_1 is None:
+             return [1, 1] + ([0] * len(token_ids_0)) + [1]
+         # [bos_token, cls_token, tokens_0, sep_token, tokens_1, cls_token]
+         return [1, 1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
+
+     def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
+         """
+         Save the vocabulary and special tokens file to a directory.
+
+         Args:
+             save_directory (str): The directory in which to save the vocabulary.
+             filename_prefix (Optional[str]): Optional prefix for the saved filename.
+
+         Returns:
+             Tuple[str]: Paths to the files saved.
+
+         Note:
+             If `save_directory` is not a valid directory, an error is logged and
+             nothing is saved.
+         """
+         if not os.path.isdir(save_directory):
+             logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+             return
+         out_vocab_file = os.path.join(
+             save_directory,
+             (filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["vocab_file"],
+         )
+
+         if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+             copyfile(self.vocab_file, out_vocab_file)
+         elif not os.path.isfile(self.vocab_file):
+             with open(out_vocab_file, "wb") as fi:
+                 content_spiece_model = self.sp_model.serialized_model_proto()
+                 fi.write(content_spiece_model)
+
+         return (out_vocab_file,)
+
+     def _pad(
+         self,
+         encoded_inputs: Union[Dict],
+         max_length: Optional[int] = None,
+         padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+         pad_to_multiple_of: Optional[int] = None,
+         padding_side: Optional[str] = None,
+         return_attention_mask: Optional[bool] = None,
+     ) -> dict:
+         """
+         Pad encoded inputs according to the specified strategy.
+
+         Args:
+             encoded_inputs (Union[Dict]): Dictionary of encoded inputs.
+             max_length (Optional[int]): Maximum length to pad to.
+             padding_strategy (PaddingStrategy): Strategy for padding.
+             pad_to_multiple_of (Optional[int]): Pad to a multiple of this value.
+             padding_side (Optional[str]): Accepted for API compatibility; this
+                 implementation uses `self.padding_side`.
+             return_attention_mask (Optional[bool]): Whether to return attention mask.
+
+         Returns:
+             dict: Dictionary with padded inputs and optional attention mask.
+
+         Raises:
+             ValueError: If attention_mask has an unexpected type or the padding side is invalid.
+         """
+         if return_attention_mask is None:
+             return_attention_mask = "attention_mask" in self.model_input_names
+         if return_attention_mask:
+             required_input = encoded_inputs[self.model_input_names[0]]
+             if padding_strategy == PaddingStrategy.LONGEST:
+                 max_length = len(required_input)
+             if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
+                 max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
+             needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
+
+             if "attention_mask" in encoded_inputs and encoded_inputs["attention_mask"] is not None:
+                 attention_mask = encoded_inputs.pop("attention_mask")
+                 if isinstance(attention_mask, torch.Tensor):
+                     attention_mask = attention_mask.numpy()
+                 elif isinstance(attention_mask, list):
+                     attention_mask = np.array(attention_mask)
+                 elif not isinstance(attention_mask, np.ndarray):
+ raise ValueError(f"Unexpected type {type(attention_mask)} of attention_mask, ")
317
+             else:
+                 # Create default attention mask if none provided
+                 attention_mask = np.tril(np.ones((len(required_input), len(required_input)), dtype=np.int64))
+                 attention_mask = np.expand_dims(attention_mask, axis=0)
+
+             if needs_to_be_padded:
+                 difference = max_length - len(required_input)
+                 if self.padding_side == "right":
+                     if attention_mask.ndim == 1:
+                         pad_width = [(0, difference)]
+                     else:
+                         pad_width = [(0, 0), (0, difference), (0, difference)]
+                 elif self.padding_side == "left":
+                     if attention_mask.ndim == 1:
+                         pad_width = [(difference, 0)]
+                     else:
+                         pad_width = [(0, 0), (difference, 0), (difference, 0)]
+                 else:
+ raise ValueError("Invalid padding strategy:" + str(self.padding_side))
336
+                 attention_mask = np.pad(
+                     attention_mask,
+                     pad_width=pad_width,
+                     mode="constant",
+                     constant_values=0,
+                 )
+
+         encoded_inputs = super()._pad(
+             encoded_inputs,
+             max_length,
+             padding_strategy=padding_strategy,
+             pad_to_multiple_of=pad_to_multiple_of,
+             return_attention_mask=False,
+         )
+         if return_attention_mask:
+             encoded_inputs["attention_mask"] = attention_mask.tolist()
+         return encoded_inputs
+
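
One quirk of the `_pad` implementation above is worth calling out: when no mask is supplied it materializes a per-sample lower-triangular matrix, so padded outputs carry a `[1, seq, seq]` causal mask instead of the usual flat 0/1 vector, and padding extends that matrix with zero rows and columns. A small sketch of the expected behavior, reusing `tok` from the earlier smoke test (exact shapes may vary with the installed `transformers` version):

    import numpy as np

    enc = tok("Hello world")
    padded = tok.pad(enc, padding="max_length", max_length=8)
    mask = np.array(padded["attention_mask"])
    print(mask.shape)   # expected (1, 8, 8): causal matrix, zero-padded on the right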
tokenizer.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:34ef7db83df785924fb83d7b887b6e822a031c56e15cff40aaf9b982988180df
+ size 1614363
tokenizer_config.json ADDED
@@ -0,0 +1,22 @@
+ {
+   "bos_token": "<s>",
+   "eos_token": "</s>",
+   "pad_token": "<unk>",
+   "unk_token": "<unk>",
+   "cls_token": "<|begin_of_sentence|>",
+   "sep_token": "<|end_of_sentence|>",
+   "mask_token": "<mask:1>",
+   "sys_start_token": "<mask:4>",
+   "sys_end_token": "<mask:5>",
+   "header_start_token": "<mask:6>",
+   "header_end_token": "<mask:7>",
+   "additional_special_tokens": null,
+   "chat_template": "{%- if not add_generation_prompt is defined -%}\n {%- set add_generation_prompt = true -%}\n{%- endif -%}\n{%- if not cls_token is defined -%}\n {%- set cls_token = \"<|begin_of_sentence|>\" -%}\n{%- endif -%}\n{%- if not sep_token is defined -%}\n {%- set sep_token = \"<|end_of_sentence|>\" -%}\n{%- endif -%}\n{{- cls_token -}}\n{%- for message in messages -%}\n {%- if message[\"role\"] == \"user\" -%}\n {{- \"User: \" + message[\"content\"] + \"\n\" -}}\n {%- elif message[\"role\"] == \"assistant\" -%}\n {{- \"Assistant: \" + message[\"content\"] + sep_token -}}\n {%- elif message[\"role\"] == \"system\" -%}\n {{- message[\"content\"] + \"\n\" -}}\n {%- endif -%}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{- \"Assistant: \" -}}\n{%- endif -%}",
+   "tokenizer_class": "Ernie4_5_Tokenizer",
+   "auto_map": {
+     "AutoTokenizer": [
+       "tokenization_ernie4_5.Ernie4_5_Tokenizer",
+       "tokenization_ernie4_5.Ernie4_5_Tokenizer"
+     ]
+   }
+ }
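
Finally, the `chat_template` entry renders conversations as a plain `User:` / `Assistant:` transcript opened by the cls token, with each assistant turn closed by the sep token. A short sketch with the same placeholder tokenizer as before:

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hi!"},
    ]
    prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    print(prompt)
    # <|begin_of_sentence|>You are a helpful assistant.
    # User: Hi!
    # Assistant: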