nitsanluke committed on
Commit
dac8a18
·
0 Parent(s):

initial commit

.gitattributes ADDED
@@ -0,0 +1,36 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,279 @@
1
+ ---
2
+ license: mit
3
+ pipeline_tag: text-generation
4
+ library_name: transformers
5
+ track_downloads: true
6
+ ---
7
+
8
+ # Apriel-H1-15b-Thinker
9
+
10
+ [![Use in Transformers](https://img.shields.io/badge/%F0%9F%A4%97%20Use%20in%20Transformers-Docs-5C5CFF)](https://huggingface.co/docs/transformers/index)
11
+
12
+ <img src="images/apriel_h.png" width="120" alt="thumbnail"/> `/ˈɑː.pri.əl/`
13
+
14
+ A 15B-parameter **hybrid reasoning model** combining Transformer attention and Mamba State Space layers for high efficiency and scalability. Derived from *Apriel-Nemotron-15B-Thinker* through progressive distillation, **Apriel-H1** replaces less critical attention layers with linear Mamba blocks—achieving **over 2× higher inference throughput** in vLLM with **minimal loss in reasoning, math, and coding performance**.
15
+
16
+ - **Model Size:** 15B parameters
17
+ - **Context Length:** 65K (target; runtime dependent)
18
+ - **Languages:** English (best)
19
+
20
+ ## Highlights
21
+
22
+ - Hybrid Transformer–SSM architecture
23
+ - ~2× throughput improvement over the base Thinker model
24
+ - Retains strong reasoning, math, and coding capabilities
25
+ - Built via efficient distillation—no training from scratch required
26
+
27
+
28
+ ## Model Overview
29
+ Apriel-H1-15b-Thinker is designed for agentic tasks, code assistance, and multi-step reasoning. It follows Apriel’s “think then answer” style: the model first produces a hidden chain-of-thought and then a concise final response. Where reasoning traces are undesired, configure prompts to favor concise outputs.
30
+
31
+ **Technical report**: <a href="https://github.com/ServiceNow/apriel/blob/main/assets/Apriel-H1.pdf" target="_blank" rel="noopener noreferrer">Apriel-H1 Report</a>
32
+
33
+ ### Efficient and strong among hybrids
34
+ ![Throughput 1->16K](./images/throughput_eval_score_vs_throughput_1-16k_annotated.png)
35
+
36
+ All models were evaluated through vLLM server endpoints using FlashInfer, except AI21-Jamba-Reasoning-3B, which used FlashAttention2. mamba_cache was set to fp32 for NVIDIA-Nemotron-Nano-9B-v2 and AI21-Jamba-Reasoning-3B.
37
+
38
+ #### Comparison with Thinker: ~2× speedup
39
+ <img src="images/apriel_h_vs_apriel_15b_eval_thrput_comparison.png" width='auto'>
40
+
41
+
42
+ ## How to Use
43
+
44
+ Install dependencies:
45
+
46
+ ```bash
47
+ pip install transformers==4.53.2
48
+ ```
49
+
50
+ Basic usage with Transformers generate:
51
+
52
+ ```python
53
+ import re
54
+ from transformers import AutoModelForCausalLM, AutoTokenizer
55
+
56
+ model_name = "ServiceNow-AI/Apriel-H1-15b-Thinker-SFT"
57
+
58
+ # load the tokenizer and the model
59
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
60
+ model = AutoModelForCausalLM.from_pretrained(
61
+     model_name,
62
+     torch_dtype="auto",
63
+     device_map="auto"
64
+ )
65
+
66
+ prompt = "Positive real numbers $x$ and $y$ satisfy $y^3=x^2$ and $(y-x)^2=4y^2$. What is $x+y$?\nMark your solution with \\boxed"
67
+ messages = [
68
+     {"role": "user", "content": prompt}
69
+ ]
70
+
71
+ text = tokenizer.apply_chat_template(
72
+     messages,
73
+     tokenize=False,
74
+     add_generation_prompt=True,
75
+     tools=[]
76
+ )
77
+ model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
78
+
79
+ generated_ids = model.generate(**model_inputs, max_new_tokens=1024)
80
+ output = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
81
+
82
+ response = re.findall(r"\[BEGIN FINAL RESPONSE\](.*?)\[END FINAL RESPONSE\]", output, re.DOTALL)[-1].strip()
83
+ print("response:", response)
84
+ ```
85
+
86
+ Recommended settings: temperature 0.6; increase `max_new_tokens` for complex reasoning.
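The extraction step above can be made more defensive. The following is a minimal sketch, assuming the decoded text may also contain the prompt (whose default system message mentions the same markers) and that generation may stop before the closing marker is produced. The helper name `extract_final_response` is hypothetical and not part of this repository.

```python
import re

# Markers taken from the chat template; the brackets are escaped so they match literally.
FINAL_RESPONSE_RE = re.compile(
    r"\[BEGIN FINAL RESPONSE\](.*?)\[END FINAL RESPONSE\]", re.DOTALL
)


def extract_final_response(output: str) -> str:
    """Return the last marked final response, or the whole text if no markers are found."""
    matches = FINAL_RESPONSE_RE.findall(output)
    if matches:
        # The last match is used because an earlier one may come from the echoed prompt.
        return matches[-1].strip()
    # Fallback (an assumption of this sketch): hand back everything that was generated.
    return output.strip()
```

Falling back to the full text is only one possible policy; raising an error or retrying with a larger `max_new_tokens` may suit some pipelines better.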
87
+
88
+ ## Use it with vLLM
89
+
90
+ ### 💻 Local Installation
91
+
92
+ #### 1. Create and activate a Python environment
93
+ You can use any environment manager. The example below uses [`uv`](https://github.com/astral-sh/uv):
94
+
95
+ ```bash
96
+ uv venv --python 3.12 --seed
97
+ source .venv/bin/activate
98
+ ```
99
+
100
+ #### 2. Install vLLM and the Apriel plugin
101
+
102
+ Find our plugin at [https://github.com/ServiceNow/apriel](https://github.com/ServiceNow/apriel).
103
+ You may need to install a version of **vLLM** compatible with your **CUDA** version.
104
+
105
+ In this example, we use the default CUDA version and let vLLM automatically select the correct backend.
106
+
107
+ ```bash
108
+ git clone git@github.com:ServiceNow/apriel.git
109
+ cd apriel
110
+ uv pip install vllm==0.10.2 --torch-backend=auto
111
+ pip install .
112
+ ```
113
+
114
+ <!--
115
+ ### 🐳 Use Prebuilt Docker Image
116
+
117
+ For convenience, prebuilt Docker images are available from **GitHub Container Registry (GHCR):**
118
+
119
+ **Repository:**
120
+ [https://github.com/ServiceNow/apriel](https://github.com/ServiceNow/apriel)
121
+
122
+ **Container packages:**
123
+ [https://github.com/ServiceNow/apriel/pkgs/container/apriel](https://github.com/ServiceNow/apriel/pkgs/container/apriel)
124
+
125
+ Pull the latest version or a specific tagged build (e.g., commit SHA):
126
+
127
+ ```bash
128
+ docker pull ghcr.io/servicenow/apriel:latest
129
+ # or a specific digest
130
+ docker pull ghcr.io/servicenow/apriel:sha-e41528d
131
+
132
+ ```
133
+ -->
134
+ ### 🧠 Running a vLLM Server
135
+
136
+ #### Option 1: Run locally (from source install)
137
+
138
+ Once installed, you can launch a vLLM OpenAI-compatible API server with your Apriel model:
139
+
140
+ ```bash
141
+ vllm serve \
142
+     --model ServiceNow-AI/Apriel-H1-15b-Thinker-SFT \
143
+     --port 8000
144
+ ```
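Once the server is running, it can be queried over HTTP. The sketch below uses only the Python standard library and assumes the server listens on `localhost:8000` and exposes the OpenAI-compatible `/v1/chat/completions` route; the helper names `build_chat_request` and `chat` are hypothetical.

```python
import json
import urllib.request

MODEL_NAME = "ServiceNow-AI/Apriel-H1-15b-Thinker-SFT"


def build_chat_request(prompt, base_url="http://localhost:8000",
                       temperature=0.6, max_tokens=1024):
    """Build a POST request for the chat completions route (no network access here)."""
    payload = {
        "model": MODEL_NAME,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": temperature,  # 0.6 is the setting recommended in this README
        "max_tokens": max_tokens,
    }
    return urllib.request.Request(
        f"{base_url}/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def chat(prompt, **kwargs):
    """Send the request and return the assistant message text."""
    with urllib.request.urlopen(build_chat_request(prompt, **kwargs)) as resp:
        body = json.load(resp)
    return body["choices"][0]["message"]["content"]
```

Calling `chat("What is 2 + 2?")` against a running server should return the full generated text, including the reasoning that precedes the final response markers.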
145
+
146
+ #### Option 2: Run via Docker
147
+
148
+ You can run the server directly using the prebuilt container:
149
+
150
+ ```bash
151
+ docker run --runtime nvidia --gpus all \
152
+     -v ~/.cache/huggingface:/root/.cache/huggingface \
153
+     --env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
154
+     -p 8000:8000 \
155
+     --ipc=host \
156
+     ghcr.io/servicenow/apriel:latest \
157
+     --model ServiceNow-AI/Apriel-H1-15b-Thinker-SFT
158
+ ```
159
+
160
+ ## Chat Template
161
+
162
+ ```
163
+ <|system|>
164
+ You are a thoughtful and systematic AI assistant built by ServiceNow Language Models (SLAM) lab. Before providing an answer, analyze the problem carefully and present your reasoning step by step. After explaining your thought process, provide the final solution in the following format: [BEGIN FINAL RESPONSE] ... [END FINAL RESPONSE].
165
+ <|end|>
166
+ <|user|>
167
+ # user message here
168
+ <|end|>
169
+ <|assistant|>
170
+ Here are my reasoning steps:
171
+ # thoughts here
172
+ [BEGIN FINAL RESPONSE]
173
+ # assistant response here
174
+ [END FINAL RESPONSE]
175
+ <|end|>
176
+ ```
177
+ The model will first generate its thinking process and then generate its final response between `[BEGIN FINAL RESPONSE]` and `[END FINAL RESPONSE]`. Here is a code snippet demonstrating the application of the chat template:
178
+
179
+
180
+
181
+ ```python
182
+ from transformers import AutoTokenizer
183
+ model_name = "ServiceNow-AI/Apriel-H1-15b-Thinker-SFT"
184
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
185
+
186
+ # prepare the model input
187
+ custom_system_prompt = "Answer like a pirate."
188
+ prompt = "You are an expert assistant in the implementation of customer experience management aspect of retail applications \n \nYou will be using Python as the programming language. \n \nYou will utilize a factory design pattern for the implementation and following the dependency inversion principle \n \nYou will modify the implementation based on user requirements. \n \nUpon user request, you will add, update, and remove the features & enhancements in the implementation provided by you. \n \nYou will ask whether the user wants to refactor the provided code or needs a sample implementation for reference. Upon user confirmation, I will proceed accordingly. \n \n**Guidelines:** \n 1. **User Requirements:** \n - You have to ask users about their requirements, clarify the user expectations, and suggest the best possible solution by providing examples of Python code snippets. \n - Ask users about which type of reports they need to assess the AI model's performance, accuracy, and reliability. \n - After providing the solution, you have to ask the user about the trial of the solution and modify the solution based on the user feedback. \n \n 2. **Libraries/Frameworks:** \n - You will be utilizing Python as a programming language. \n - You will be using Flask framework for REST APIS implementation \n \n 3. **Communication Gesture:** \n - Your conversation with the user should be interactive, supportive, courageous, and professional. \n - You have to break down the complex concepts into sub-concepts and try to explain them to the user. \n - You have to ask the user for the required parameters. If the user refuses to provide in 2 attempts, politely exit the conversation. \n - You have to provide your supported parameters to the user, if the user refuses to accept them then you have to put an apology note and exit the conversation. \n - You have to track the conversation about unasked questions by the user. If some/one of the questions remain then you have to remind the user about these questions and proceed to answer them based on the user's confirmation \n \n 4. **Implementation:** \n - Your code/implementations should be reliable, scalable, modular, and reusable. \n - You will be providing unit tests for the implementation upon user request. \n - You will be following MVC architecture for the applications \n - Your implementations must be well-commented and readable \n \n \n- Today's date is 23rd August 2024. \n- The default sender email is [email protected].\nHi, I am conducting research on retail customer feedback systems and I need assistance with designing and implementing them. Could you kindly provide me with a list of general customer feedback system modules?"
189
+ messages = [
190
+     {"role": "user", "content": custom_system_prompt + "\n\n" + prompt}
191
+ ]
192
+ # example tools
193
+ tools = [{"type": "function", "function": {"name": "getRetailFeedbackModules", "description": "Returns the list of modules usually present in the retail industry", "parameters": {"type": "object", "properties": {"page": {"type": "integer", "description": "The current page number.", "default": 1}, "page_size": {"type": "integer", "description": "The number of items per page.", "default": 3}}}}}, {"type": "function", "function": {"name": "verifyImplementation", "description": "Returns the list of modules usually present in the retail industry", "parameters": {"type": "object", "properties": {"coding_language": {"type": "string", "description": "The supported languages for verification of implementation.", "default": "python", "enum": ["python", "java", "php"]}, "code": {"type": "string", "description": "The code which needs verification"}, "design_pattern": {"type": "string", "description": "The design pattern to verify in the implementation", "enum": ["factory", "strategy", "singleton"]}, "verify_best_practices": {"type": "boolean", "description": "The verification of the coding style based on the language selected", "default": True}}}}}]
194
+ text = tokenizer.apply_chat_template(
195
+     messages,
196
+     tokenize=False,
197
+     add_generation_prompt=True,
198
+     tools=tools
199
+ )
200
+ model_inputs = tokenizer([text], return_tensors="pt")
201
+ ```
202
+
203
+ ### Usage Guidelines
204
+ 1. Use the model’s default chat template, which already includes a system prompt. We recommend adding all other instructions within the user message.
205
+ 2. We recommend setting temperature to `0.6`.
206
+ 3. We ensure the model starts with `Here are my reasoning steps:\n` during all our evaluations. This is implemented in the default chat template.
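The guidelines above can be captured in a small helper. This is a sketch under stated assumptions: `build_messages` and `GENERATION_KWARGS` are hypothetical names, and `do_sample=True` is included because Transformers' `generate` only applies `temperature` when sampling is enabled.

```python
def build_messages(user_prompt, extra_instructions=None):
    """Place any extra instructions inside the user turn, leaving the default system prompt alone."""
    if extra_instructions:
        content = f"{extra_instructions}\n\n{user_prompt}"
    else:
        content = user_prompt
    return [{"role": "user", "content": content}]


# Keyword arguments intended for model.generate(**model_inputs, **GENERATION_KWARGS).
GENERATION_KWARGS = {
    "do_sample": True,    # sampling must be on for temperature to have an effect
    "temperature": 0.6,   # value recommended in the guidelines
}
```

For example, `build_messages("List the modules.", "Answer like a pirate.")` yields a single user message whose content starts with the extra instruction, followed by a blank line and the prompt.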
207
+
208
+
209
+
210
+ ## Intended Use
211
+
212
+ The Apriel family of models is designed for a variety of general-purpose instruction tasks, including:
213
+
214
+ - Code assistance and generation
215
+ - Logical reasoning and multi-step tasks
216
+ - Question answering and information retrieval
217
+ - Function calling, complex instruction following and agent use cases
218
+
219
+ They are **not intended** for use in safety-critical applications without human oversight or in scenarios requiring guaranteed factual accuracy.
220
+
221
+ ---
222
+
223
+ ## Limitations
224
+
225
+ - **Factual accuracy:** May produce incorrect, misleading, or outdated content. Outputs should be verified before use in critical contexts.
226
+ - **Bias:** May reflect societal, cultural, or systemic biases present in training data.
227
+ - **Ethics:** Do not use the model to produce harmful, unlawful, or unethical content.
228
+ - **Language:** Strongest performance is in English. Output quality may degrade in underrepresented languages.
229
+ - **Critical use:** Not suitable for medical, legal, financial, or other high-risk applications without safeguards.
230
+
231
+ ---
232
+
233
+ ## Security and Responsible Use
234
+
235
+ **Security Responsibilities:**
236
+ Deployers and users are strongly encouraged to align their security practices with established frameworks and regulatory guidelines such as the EU AI Act and the NIST AI Risk Management Framework (RMF).
237
+
238
+ **Guidelines for Deployers:**
239
+
240
+ - Regularly conduct robustness assessments to identify and mitigate adversarial inputs.
241
+ - Implement validation and filtering processes to prevent harmful or biased outputs.
242
+ - Continuously perform data privacy checks to guard against unintended data leaks.
243
+ - Document and communicate the model's limitations, intended usage, and known security risks to all end-users.
244
+ - Schedule periodic security reviews and updates to address emerging threats and vulnerabilities.
245
+
246
+ **Guidelines for Users:**
247
+
248
+ - Follow established security policies and usage guidelines provided by deployers.
249
+ - Protect and manage sensitive information when interacting with the model.
250
+ - Report anomalies, suspicious behavior, or unsafe outputs to deployers or developers.
251
+ - Maintain human oversight and apply judgment to mitigate potential security or ethical risks during interactions.
252
+
253
+ **Disclaimer:**
254
+ Users accept responsibility for securely deploying, managing, and using this open-source LLM. The model is provided "as-is," without explicit or implied warranty regarding security or fitness for any specific application or environment.
255
+
256
+ ---
257
+
258
+ ## Software
259
+
260
+ - **Training stack:** [Fast-LLM](https://github.com/ServiceNow/Fast-LLM)
261
+
262
+ ---
263
+
264
+ ## License
265
+
266
+ MIT
267
+
268
+ ---
269
+
270
+ ## Citation
271
+
272
+ ```bibtex
273
+ @misc{apriel_h1_2025,
274
+ title = {Apriel-H1: Towards Efficient Enterprise Reasoning Models},
275
+ author = {ServiceNow Language Models Lab},
276
+ howpublished = {https://huggingface.co/ServiceNow-AI/Apriel-H1-15b-Thinker-SFT},
277
+ year = {2025}
278
+ }
279
+ ```
config.json ADDED
@@ -0,0 +1,104 @@
1
+ {
2
+ "architectures": [
3
+ "AprielHForCausalLM"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_apriel_h.AprielHConfig",
8
+ "AutoModel": "modeling_apriel_h.AprielHModel",
9
+ "AutoModelForCausalLM": "modeling_apriel_h.AprielHForCausalLM"
10
+ },
11
+ "bos_token_id": 1,
12
+ "eos_token_id": 2,
13
+ "head_dim": 128,
14
+ "hidden_act": "silu",
15
+ "hidden_size": 5120,
16
+ "hybrid_block_layout": [
17
+ "t",
18
+ "t",
19
+ "t",
20
+ "m2",
21
+ "t",
22
+ "m2",
23
+ "m2",
24
+ "m2",
25
+ "m2",
26
+ "t",
27
+ "t",
28
+ "t",
29
+ "t",
30
+ "t",
31
+ "t",
32
+ "t",
33
+ "m2",
34
+ "t",
35
+ "m2",
36
+ "m2",
37
+ "m2",
38
+ "m2",
39
+ "m2",
40
+ "m2",
41
+ "m2",
42
+ "m2",
43
+ "m2",
44
+ "m2",
45
+ "t",
46
+ "t",
47
+ "m2",
48
+ "m2",
49
+ "m2",
50
+ "m2",
51
+ "t",
52
+ "m2",
53
+ "m2",
54
+ "m2",
55
+ "m2",
56
+ "m2",
57
+ "m2",
58
+ "m2",
59
+ "m2",
60
+ "m2",
61
+ "m2",
62
+ "m2",
63
+ "m2",
64
+ "m2",
65
+ "t",
66
+ "m2"
67
+ ],
68
+ "initializer_range": 0.02,
69
+ "intermediate_size": 14336,
70
+ "max_position_embeddings": 65536,
71
+ "model_type": "apriel_h",
72
+ "num_attention_heads": 32,
73
+ "num_hidden_layers": 50,
74
+ "num_key_value_heads": 8,
75
+ "rms_norm_eps": 1e-05,
76
+ "rope_scaling": {
77
+ "rope_type": "default"
78
+ },
79
+ "rope_theta": 1000000.0,
80
+ "sliding_window": null,
81
+ "ssm_cfg": {
82
+ "activation": "silu",
83
+ "bias": false,
84
+ "chunk_size": 128,
85
+ "conv_bias": true,
86
+ "d_conv": 4,
87
+ "d_inner": 4096,
88
+ "d_state": 16,
89
+ "d_xb": 1024,
90
+ "dt_init": "random",
91
+ "dt_init_floor": 0.0001,
92
+ "dt_max": 0.1,
93
+ "dt_min": 0.001,
94
+ "dt_rank": 320,
95
+ "dt_scale": 1.0,
96
+ "expand": 1,
97
+ "n_qk_heads": 32,
98
+ "n_v_heads": 32
99
+ },
100
+ "tie_word_embeddings": false,
101
+ "transformers_version": "4.53.2",
102
+ "use_cache": true,
103
+ "vocab_size": 131072
104
+ }
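The effect of `hybrid_block_layout` on cache size can be worked out directly from the values in config.json. The sketch below reads `t` as an attention block and `m2` as a Mamba block, which is consistent with the weight names in model.safetensors.index.json (layers 0 to 2 carry `self_attn.*` weights, layer 3 carries `mixer.*` weights). The 2-byte cache element size is an assumption, and the fixed-size Mamba state is not counted.

```python
from collections import Counter

# Run-length encoding of "hybrid_block_layout" as listed in config.json.
LAYOUT_RUNS = [
    ("t", 3), ("m2", 1), ("t", 1), ("m2", 4), ("t", 7), ("m2", 1), ("t", 1),
    ("m2", 10), ("t", 2), ("m2", 4), ("t", 1), ("m2", 13), ("t", 1), ("m2", 1),
]
layout = [kind for kind, run in LAYOUT_RUNS for _ in range(run)]
counts = Counter(layout)

NUM_KV_HEADS = 8      # "num_key_value_heads"
HEAD_DIM = 128        # "head_dim"
BYTES_PER_VALUE = 2   # assumption: 16-bit KV cache entries


def kv_cache_bytes_per_token(num_attention_layers: int) -> int:
    """Keys plus values, for every KV head and head dimension, in each attention layer."""
    return 2 * NUM_KV_HEADS * HEAD_DIM * BYTES_PER_VALUE * num_attention_layers


hybrid_bytes = kv_cache_bytes_per_token(counts["t"])
all_attention_bytes = kv_cache_bytes_per_token(len(layout))
print(counts["t"], counts["m2"], hybrid_bytes, all_attention_bytes)
# prints: 16 34 65536 204800
```

Under these assumptions the hybrid keeps 16 of 50 layers as attention, so the per-token KV cache is 64 KiB instead of 200 KiB for an all-attention stack of the same shape.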
configuration_apriel_h.py ADDED
@@ -0,0 +1,37 @@
1
+ from transformers import MistralConfig
2
+ from transformers.utils import logging
3
+
4
+ logger = logging.get_logger(__name__)
5
+
6
+ ssm_config_default = {
7
+     "d_state": 64,
8
+     "n_qk_heads": 32,
9
+     "expand": 1,
10
+     "chunk_size": 128,
11
+     "activation": "identity",
12
+     "bias": False,
13
+     "d_conv": 4,
14
+     "d_inner": 32 * 128,
15
+     "d_xb": None,  # will be set to model dim
16
+     "dt_rank": "auto",
17
+     "dt_min": 0.001,
18
+     "dt_max": 0.1,
19
+     "dt_init": "random",
20
+     "dt_scale": 1.0,
21
+     "dt_init_floor": 1e-4,
22
+     "conv_bias": True,
23
+ }
24
+
25
+
26
+ class AprielHConfig(MistralConfig):
27
+     model_type = "apriel_h"
28
+
29
+     def __init__(self, hybrid_block_layout=["m2"], ssm_cfg=None, **kwargs):
30
+         super().__init__(**kwargs)
31
+         self.hybrid_block_layout = hybrid_block_layout
32
+         self.head_dim = self.head_dim or self.hidden_size // self.num_attention_heads  # as in transformers 4.51.3
33
+         self.ssm_cfg = ssm_cfg or ssm_config_default
34
+
35
+         for k, v in ssm_config_default.items():
36
+             if k not in self.ssm_cfg:
37
+                 self.ssm_cfg[k] = v  # to make sure all elements are present in the config
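The default-filling loop in `AprielHConfig.__init__` can be illustrated in isolation. The sketch below is a variation, not the repository's code: it merges into a fresh copy so that neither the caller's dictionary nor the module-level defaults are modified. The names `SSM_CONFIG_DEFAULT` and `resolve_ssm_cfg` are hypothetical.

```python
import copy

# Same keys and values as ssm_config_default in configuration_apriel_h.py.
SSM_CONFIG_DEFAULT = {
    "d_state": 64,
    "n_qk_heads": 32,
    "expand": 1,
    "chunk_size": 128,
    "activation": "identity",
    "bias": False,
    "d_conv": 4,
    "d_inner": 32 * 128,
    "d_xb": None,
    "dt_rank": "auto",
    "dt_min": 0.001,
    "dt_max": 0.1,
    "dt_init": "random",
    "dt_scale": 1.0,
    "dt_init_floor": 1e-4,
    "conv_bias": True,
}


def resolve_ssm_cfg(ssm_cfg=None):
    """Return the defaults overlaid with user-supplied values, without mutating either input."""
    merged = copy.deepcopy(SSM_CONFIG_DEFAULT)
    merged.update(ssm_cfg or {})
    return merged


# Values from config.json override the defaults; missing keys are filled in.
resolved = resolve_ssm_cfg({"d_state": 16, "activation": "silu", "d_xb": 1024})
```

Here `resolved["d_state"]` is 16 while `resolved["d_conv"]` keeps the default of 4, and `SSM_CONFIG_DEFAULT` itself is left unchanged.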
images/apriel_h.png ADDED

Git LFS Details

  • SHA256: ebb37e2644bf76eeae0fd5c8e480a29775d2416e659ee460f52770aef23a635b
  • Pointer size: 131 Bytes
  • Size of remote file: 449 kB
images/apriel_h_vs_apriel_15b_eval_thrput_comparison.png ADDED

Git LFS Details

  • SHA256: 2d0440f4736cebb3d097532840dd4d7b492b513a22cede8f8789d2dbc6e20914
  • Pointer size: 131 Bytes
  • Size of remote file: 326 kB
images/throughput_eval_score_vs_throughput_1-16k_annotated.png ADDED

Git LFS Details

  • SHA256: 0726f12817837573939cb0e917ebaee9fb9a72410104df4d23afe91bc080e5fd
  • Pointer size: 131 Bytes
  • Size of remote file: 122 kB
model.safetensors.index.json ADDED
@@ -0,0 +1,915 @@
1
+ {
2
+ "metadata": {
3
+ "fast_llm_metadata": {
4
+ "fast_llm_version": "0.2.0",
5
+ "model": "hybrid_ssm",
6
+ "format": "apriel_ssm_thinker_hybrid",
7
+ "config": {
8
+ "type": "hybrid_ssm",
9
+ "base_model": {
10
+ "transformer": {
11
+ "type": "lm_decoder",
12
+ "normalization": {
13
+ "type": "rms_norm",
14
+ "epsilon": 1e-05
15
+ },
16
+ "rotary": {
17
+ "type": "default",
18
+ "theta": 1000000.0
19
+ },
20
+ "peft": {
21
+ "type": "none"
22
+ },
23
+ "num_layers": 50,
24
+ "hidden_size": 5120,
25
+ "num_attention_heads": 32,
26
+ "head_groups": 8,
27
+ "add_linear_biases": false,
28
+ "ffn_hidden_size": 14336,
29
+ "kv_channels": 128,
30
+ "gated": true,
31
+ "activation_type": "silu",
32
+ "mlp_lr_scale": 1.0,
33
+ "attention_lr_scale": 1.0
34
+ },
35
+ "vision_encoder": {
36
+ "transformer": {
37
+ "normalization": {
38
+ "type": "layer_norm"
39
+ },
40
+ "rotary": {
41
+ "type": "none"
42
+ },
43
+ "peft": {
44
+ "type": "none"
45
+ }
46
+ },
47
+ "patch_norm": {
48
+ "type": "layer_norm"
49
+ }
50
+ },
51
+ "vocab_size": 131072,
52
+ "use_position_embeddings": false,
53
+ "tie_word_embeddings": false,
54
+ "cross_entropy_impl": "fused",
55
+ "distillation_loss_implementation": "reverse_kl",
56
+ "distillation_model": "teacher",
57
+ "parallel_embeddings": false,
58
+ "embeddings_lr_scale": 1.0,
59
+ "output_lr_scale": 1.0,
60
+ "ssm": {
61
+ "normalization": {
62
+ "type": "layer_norm"
63
+ },
64
+ "expansion_factor": 1,
65
+ "state_size": 16,
66
+ "conv_kernel_dimension": 4,
67
+ "dt_rank": 320,
68
+ "n_qk_heads": 32,
69
+ "n_v_heads": 32,
70
+ "d_inner": 4096,
71
+ "d_xb": 1024,
72
+ "add_bias_linear": false,
73
+ "activation_type": "silu",
74
+ "chunk_size": 128,
75
+ "dt_init": "random",
76
+ "dt_scale": 1.0,
77
+ "dt_min": 0.001,
78
+ "dt_max": 0.1,
79
+ "dt_init_floor": 0.0001
80
+ },
81
+ "hybrid_block_layout": [
82
+ "t",
83
+ "t",
84
+ "t",
85
+ "m2",
86
+ "t",
87
+ "m2",
88
+ "m2",
89
+ "m2",
90
+ "m2",
91
+ "t",
92
+ "t",
93
+ "t",
94
+ "t",
95
+ "t",
96
+ "t",
97
+ "t",
98
+ "m2",
99
+ "t",
100
+ "m2",
101
+ "m2",
102
+ "m2",
103
+ "m2",
104
+ "m2",
105
+ "m2",
106
+ "m2",
107
+ "m2",
108
+ "m2",
109
+ "m2",
110
+ "t",
111
+ "t",
112
+ "m2",
113
+ "m2",
114
+ "m2",
115
+ "m2",
116
+ "t",
117
+ "m2",
118
+ "m2",
119
+ "m2",
120
+ "m2",
121
+ "m2",
122
+ "m2",
123
+ "m2",
124
+ "m2",
125
+ "m2",
126
+ "m2",
127
+ "m2",
128
+ "m2",
129
+ "m2",
130
+ "t",
131
+ "m2"
132
+ ]
133
+ },
134
+ "multi_stage": {
135
+ "zero_stage": 3
136
+ },
137
+ "distributed": {
138
+ "tensor_parallel": 8,
139
+ "sequence_tensor_parallel": true,
140
+ "world_size": 64,
141
+ "rank": 0,
142
+ "local_world_size": 8,
143
+ "timeout": 36000.0,
144
+ "seed": 984060,
145
+ "training_dtype": "bfloat16"
146
+ }
147
+ },
148
+ "shards": [
149
+ "weights"
150
+ ],
151
+ "metadata": {
152
+ "optimizer": {
153
+ "current_step": 25000,
154
+ "grad_scaler": {
155
+ "type": "NoopGradScaler"
156
+ }
157
+ },
158
+ "completed_steps": 25000,
159
+ "metrics": {
160
+ "Training": {
161
+ "train_iters": 60000,
162
+ "batch_size": 64,
163
+ "iteration": 25000,
164
+ "distillation_loss": 0.034176599606871604,
165
+ "language_model_loss": 0.034176599606871604,
166
+ "consumed_samples": 1600000,
167
+ "consumed_tokens": 26214400000,
168
+ "step_time_ms": 13669.109502492938,
169
+ "step_time_average_ms": 14812.520303467,
170
+ "remaining_time": 518438.21062134503,
171
+ "completion_time": 1757381892.6039205,
172
+ "percent_done": 41.666666666666664,
173
+ "skipped_iters": 0,
174
+ "nan_iters": 0,
175
+ "model_tflops": 126.99100713435496,
176
+ "hardware_tflops": 131.01289176252138,
177
+ "tokens_per_sec_per_gpu": 1198.6150229473199,
178
+ "run": 0,
179
+ "grad_norm": 0.24984633922576904,
180
+ "learning_rate": 2.968135593220339e-06,
181
+ "loss_scale": 1.0,
182
+ "reserved": 55520.0,
183
+ "allocated": 17871.4111328125,
184
+ "max_allocated": 48712.857421875,
185
+ "max_reserved": 55520.0,
186
+ "global_max_reserved": 55520.0
187
+ }
188
+ }
189
+ }
190
+ },
191
+ "model_config": {
192
+ "model_type": "apriel_ssm_thinker_hybrid",
193
+ "architectures": [
194
+ "AprielThinkerSSMHybridForCausalLM"
195
+ ],
196
+ "rope_theta": 1000000.0,
197
+ "hidden_act": "silu",
198
+ "num_hidden_layers": 50,
199
+ "hidden_size": 5120,
200
+ "num_attention_heads": 32,
201
+ "num_key_value_heads": 8,
202
+ "intermediate_size": 14336,
203
+ "vocab_size": 131072,
204
+ "tie_word_embeddings": false,
205
+ "rms_norm_eps": 1e-05,
206
+ "head_dim": 128,
207
+ "rope_scaling": {
208
+ "rope_type": "default"
209
+ },
210
+ "ssm_cfg": {
211
+ "d_state": 16,
212
+ "n_v_heads": 32,
213
+ "n_qk_heads": 32,
214
+ "expand": 1,
215
+ "chunk_size": 128,
216
+ "bias": false,
217
+ "activation": "silu",
218
+ "dt_rank": 320,
219
+ "dt_min": 0.001,
220
+ "dt_max": 0.1,
221
+ "dt_init_floor": 0.0001,
222
+ "dt_scale": 1.0,
223
+ "d_xb": 1024,
224
+ "d_conv": 4,
225
+ "dt_init": "random",
226
+ "d_inner": 4096,
227
+ "conv_bias": true
228
+ },
229
+ "hybrid_block_layout": [
230
+ "t",
231
+ "t",
232
+ "t",
233
+ "m2",
234
+ "t",
235
+ "m2",
236
+ "m2",
237
+ "m2",
238
+ "m2",
239
+ "t",
240
+ "t",
241
+ "t",
242
+ "t",
243
+ "t",
244
+ "t",
245
+ "t",
246
+ "m2",
247
+ "t",
248
+ "m2",
249
+ "m2",
250
+ "m2",
251
+ "m2",
252
+ "m2",
253
+ "m2",
254
+ "m2",
255
+ "m2",
256
+ "m2",
257
+ "m2",
258
+ "t",
259
+ "t",
260
+ "m2",
261
+ "m2",
262
+ "m2",
263
+ "m2",
264
+ "t",
265
+ "m2",
266
+ "m2",
267
+ "m2",
268
+ "m2",
269
+ "m2",
270
+ "m2",
271
+ "m2",
272
+ "m2",
273
+ "m2",
274
+ "m2",
275
+ "m2",
276
+ "m2",
277
+ "m2",
278
+ "t",
279
+ "m2"
280
+ ],
281
+ "auto_map": {
282
+ "AutoConfig": "configuration_ssm_hybrid_apriel15b.AprielSSMHybridConfig",
283
+ "AutoModel": "modeling_ssm_hybrid_apriel15b.AprielThinkerSSMHybridModel",
284
+ "AutoModelForCausalLM": "modeling_ssm_hybrid_apriel15b.AprielThinkerSSMHybridForCausalLM"
285
+ },
286
+ "attn_implementation": null
287
+ },
288
+ "format": "pt"
289
+ },
290
+ "weight_map": {
291
+ "model.embed_tokens.weight": "model_0.safetensors",
292
+ "model.layers.0.input_layernorm.weight": "model_0.safetensors",
293
+ "model.layers.0.post_attention_layernorm.weight": "model_0.safetensors",
294
+ "model.layers.0.self_attn.q_proj.weight": "model_0.safetensors",
295
+ "model.layers.0.self_attn.k_proj.weight": "model_0.safetensors",
296
+ "model.layers.0.self_attn.v_proj.weight": "model_0.safetensors",
297
+ "model.layers.0.self_attn.o_proj.weight": "model_0.safetensors",
298
+ "model.layers.0.mlp.gate_proj.weight": "model_0.safetensors",
299
+ "model.layers.0.mlp.up_proj.weight": "model_0.safetensors",
300
+ "model.layers.0.mlp.down_proj.weight": "model_0.safetensors",
301
+ "model.layers.1.input_layernorm.weight": "model_0.safetensors",
302
+ "model.layers.1.post_attention_layernorm.weight": "model_0.safetensors",
303
+ "model.layers.1.self_attn.q_proj.weight": "model_0.safetensors",
304
+ "model.layers.1.self_attn.k_proj.weight": "model_0.safetensors",
305
+ "model.layers.1.self_attn.v_proj.weight": "model_0.safetensors",
306
+ "model.layers.1.self_attn.o_proj.weight": "model_0.safetensors",
307
+ "model.layers.1.mlp.gate_proj.weight": "model_0.safetensors",
308
+ "model.layers.1.mlp.up_proj.weight": "model_0.safetensors",
309
+ "model.layers.1.mlp.down_proj.weight": "model_0.safetensors",
310
+ "model.layers.2.input_layernorm.weight": "model_0.safetensors",
311
+ "model.layers.2.post_attention_layernorm.weight": "model_0.safetensors",
312
+ "model.layers.2.self_attn.q_proj.weight": "model_0.safetensors",
313
+ "model.layers.2.self_attn.k_proj.weight": "model_0.safetensors",
314
+ "model.layers.2.self_attn.v_proj.weight": "model_0.safetensors",
315
+ "model.layers.2.self_attn.o_proj.weight": "model_0.safetensors",
316
+ "model.layers.2.mlp.gate_proj.weight": "model_0.safetensors",
317
+ "model.layers.2.mlp.up_proj.weight": "model_0.safetensors",
318
+ "model.layers.2.mlp.down_proj.weight": "model_0.safetensors",
319
+ "model.layers.3.mixer.A_log": "model_0.safetensors",
320
+ "model.layers.3.mixer.D": "model_0.safetensors",
321
+ "model.layers.3.input_layernorm.weight": "model_0.safetensors",
322
+ "model.layers.3.post_attention_layernorm.weight": "model_0.safetensors",
323
+ "model.layers.3.mixer.dt_in_proj.weight": "model_0.safetensors",
324
+ "model.layers.3.mixer.conv1d.weight": "model_0.safetensors",
325
+ "model.layers.3.mixer.conv1d.bias": "model_0.safetensors",
326
+ "model.layers.3.mixer.dt_proj.bias": "model_0.safetensors",
327
+ "model.layers.3.mixer.in_proj.weight": "model_0.safetensors",
328
+ "model.layers.3.mixer.dt_proj.weight": "model_0.safetensors",
329
+ "model.layers.3.mixer.out_proj.weight": "model_0.safetensors",
330
+ "model.layers.3.mlp.gate_proj.weight": "model_0.safetensors",
331
+ "model.layers.3.mlp.up_proj.weight": "model_0.safetensors",
332
+ "model.layers.3.mlp.down_proj.weight": "model_0.safetensors",
333
+ "model.layers.4.input_layernorm.weight": "model_0.safetensors",
334
+ "model.layers.4.post_attention_layernorm.weight": "model_0.safetensors",
335
+ "model.layers.4.self_attn.q_proj.weight": "model_0.safetensors",
336
+ "model.layers.4.self_attn.k_proj.weight": "model_0.safetensors",
337
+ "model.layers.4.self_attn.v_proj.weight": "model_0.safetensors",
338
+ "model.layers.4.self_attn.o_proj.weight": "model_0.safetensors",
339
+ "model.layers.4.mlp.gate_proj.weight": "model_0.safetensors",
340
+ "model.layers.4.mlp.up_proj.weight": "model_0.safetensors",
341
+ "model.layers.4.mlp.down_proj.weight": "model_0.safetensors",
342
+ "model.layers.5.mixer.A_log": "model_0.safetensors",
343
+ "model.layers.5.mixer.D": "model_0.safetensors",
344
+ "model.layers.5.input_layernorm.weight": "model_0.safetensors",
345
+ "model.layers.5.post_attention_layernorm.weight": "model_0.safetensors",
346
+ "model.layers.5.mixer.dt_in_proj.weight": "model_0.safetensors",
347
+ "model.layers.5.mixer.conv1d.weight": "model_0.safetensors",
348
+ "model.layers.5.mixer.conv1d.bias": "model_0.safetensors",
349
+ "model.layers.5.mixer.dt_proj.bias": "model_0.safetensors",
350
+ "model.layers.5.mixer.in_proj.weight": "model_0.safetensors",
351
+ "model.layers.5.mixer.dt_proj.weight": "model_0.safetensors",
352
+ "model.layers.5.mixer.out_proj.weight": "model_0.safetensors",
353
+ "model.layers.5.mlp.gate_proj.weight": "model_0.safetensors",
354
+ "model.layers.5.mlp.up_proj.weight": "model_0.safetensors",
355
+ "model.layers.5.mlp.down_proj.weight": "model_0.safetensors",
356
+ "model.layers.6.mixer.A_log": "model_0.safetensors",
357
+ "model.layers.6.mixer.D": "model_0.safetensors",
358
+ "model.layers.6.input_layernorm.weight": "model_0.safetensors",
359
+ "model.layers.6.post_attention_layernorm.weight": "model_0.safetensors",
360
+ "model.layers.6.mixer.dt_in_proj.weight": "model_0.safetensors",
361
+ "model.layers.6.mixer.conv1d.weight": "model_0.safetensors",
362
+ "model.layers.6.mixer.conv1d.bias": "model_0.safetensors",
363
+ "model.layers.6.mixer.dt_proj.bias": "model_0.safetensors",
364
+ "model.layers.6.mixer.in_proj.weight": "model_0.safetensors",
365
+ "model.layers.6.mixer.dt_proj.weight": "model_0.safetensors",
366
+ "model.layers.6.mixer.out_proj.weight": "model_0.safetensors",
367
+ "model.layers.6.mlp.gate_proj.weight": "model_0.safetensors",
368
+ "model.layers.6.mlp.up_proj.weight": "model_0.safetensors",
369
+ "model.layers.6.mlp.down_proj.weight": "model_0.safetensors",
370
+ "model.layers.7.mixer.A_log": "model_0.safetensors",
371
+ "model.layers.7.mixer.D": "model_0.safetensors",
372
+ "model.layers.7.input_layernorm.weight": "model_0.safetensors",
373
+ "model.layers.7.post_attention_layernorm.weight": "model_0.safetensors",
374
+ "model.layers.7.mixer.dt_in_proj.weight": "model_0.safetensors",
375
+ "model.layers.7.mixer.conv1d.weight": "model_0.safetensors",
376
+ "model.layers.7.mixer.conv1d.bias": "model_0.safetensors",
377
+ "model.layers.7.mixer.dt_proj.bias": "model_0.safetensors",
378
+ "model.layers.7.mixer.in_proj.weight": "model_0.safetensors",
379
+ "model.layers.7.mixer.dt_proj.weight": "model_0.safetensors",
380
+ "model.layers.7.mixer.out_proj.weight": "model_0.safetensors",
381
+ "model.layers.7.mlp.gate_proj.weight": "model_0.safetensors",
382
+ "model.layers.7.mlp.up_proj.weight": "model_0.safetensors",
383
+ "model.layers.7.mlp.down_proj.weight": "model_0.safetensors",
384
+ "model.layers.8.mixer.A_log": "model_0.safetensors",
385
+ "model.layers.8.mixer.D": "model_0.safetensors",
386
+ "model.layers.8.input_layernorm.weight": "model_0.safetensors",
387
+ "model.layers.8.post_attention_layernorm.weight": "model_0.safetensors",
388
+ "model.layers.8.mixer.dt_in_proj.weight": "model_0.safetensors",
389
+ "model.layers.8.mixer.conv1d.weight": "model_0.safetensors",
390
+ "model.layers.8.mixer.conv1d.bias": "model_0.safetensors",
391
+ "model.layers.8.mixer.dt_proj.bias": "model_0.safetensors",
392
+ "model.layers.8.mixer.in_proj.weight": "model_0.safetensors",
393
+ "model.layers.8.mixer.dt_proj.weight": "model_0.safetensors",
394
+ "model.layers.8.mixer.out_proj.weight": "model_0.safetensors",
395
+ "model.layers.8.mlp.gate_proj.weight": "model_0.safetensors",
396
+ "model.layers.8.mlp.up_proj.weight": "model_0.safetensors",
397
+ "model.layers.8.mlp.down_proj.weight": "model_0.safetensors",
398
+ "model.layers.9.input_layernorm.weight": "model_0.safetensors",
399
+ "model.layers.9.post_attention_layernorm.weight": "model_0.safetensors",
400
+ "model.layers.9.self_attn.q_proj.weight": "model_0.safetensors",
401
+ "model.layers.9.self_attn.k_proj.weight": "model_0.safetensors",
402
+ "model.layers.9.self_attn.v_proj.weight": "model_0.safetensors",
403
+ "model.layers.9.self_attn.o_proj.weight": "model_0.safetensors",
404
+ "model.layers.9.mlp.gate_proj.weight": "model_0.safetensors",
405
+ "model.layers.9.mlp.up_proj.weight": "model_0.safetensors",
406
+ "model.layers.9.mlp.down_proj.weight": "model_0.safetensors",
407
+ "model.layers.10.input_layernorm.weight": "model_0.safetensors",
408
+ "model.layers.10.post_attention_layernorm.weight": "model_0.safetensors",
409
+ "model.layers.10.self_attn.q_proj.weight": "model_0.safetensors",
410
+ "model.layers.10.self_attn.k_proj.weight": "model_0.safetensors",
411
+ "model.layers.10.self_attn.v_proj.weight": "model_0.safetensors",
412
+ "model.layers.10.self_attn.o_proj.weight": "model_0.safetensors",
413
+ "model.layers.10.mlp.gate_proj.weight": "model_0.safetensors",
414
+ "model.layers.10.mlp.up_proj.weight": "model_0.safetensors",
415
+ "model.layers.10.mlp.down_proj.weight": "model_0.safetensors",
416
+ "model.layers.11.input_layernorm.weight": "model_0.safetensors",
417
+ "model.layers.11.post_attention_layernorm.weight": "model_0.safetensors",
418
+ "model.layers.11.self_attn.q_proj.weight": "model_0.safetensors",
419
+ "model.layers.11.self_attn.k_proj.weight": "model_0.safetensors",
420
+ "model.layers.11.self_attn.v_proj.weight": "model_0.safetensors",
421
+ "model.layers.11.self_attn.o_proj.weight": "model_0.safetensors",
422
+ "model.layers.11.mlp.gate_proj.weight": "model_0.safetensors",
423
+ "model.layers.11.mlp.up_proj.weight": "model_0.safetensors",
424
+ "model.layers.11.mlp.down_proj.weight": "model_0.safetensors",
425
+ "model.layers.12.input_layernorm.weight": "model_0.safetensors",
426
+ "model.layers.12.post_attention_layernorm.weight": "model_0.safetensors",
427
+ "model.layers.12.self_attn.q_proj.weight": "model_0.safetensors",
428
+ "model.layers.12.self_attn.k_proj.weight": "model_0.safetensors",
429
+ "model.layers.12.self_attn.v_proj.weight": "model_0.safetensors",
430
+ "model.layers.12.self_attn.o_proj.weight": "model_0.safetensors",
431
+ "model.layers.12.mlp.gate_proj.weight": "model_0.safetensors",
432
+ "model.layers.12.mlp.up_proj.weight": "model_0.safetensors",
433
+ "model.layers.12.mlp.down_proj.weight": "model_0.safetensors",
434
+ "model.layers.13.input_layernorm.weight": "model_1.safetensors",
435
+ "model.layers.13.post_attention_layernorm.weight": "model_1.safetensors",
436
+ "model.layers.13.self_attn.q_proj.weight": "model_1.safetensors",
437
+ "model.layers.13.self_attn.k_proj.weight": "model_1.safetensors",
438
+ "model.layers.13.self_attn.v_proj.weight": "model_1.safetensors",
439
+ "model.layers.13.self_attn.o_proj.weight": "model_1.safetensors",
440
+ "model.layers.13.mlp.gate_proj.weight": "model_1.safetensors",
441
+ "model.layers.13.mlp.up_proj.weight": "model_1.safetensors",
442
+ "model.layers.13.mlp.down_proj.weight": "model_1.safetensors",
443
+ "model.layers.14.input_layernorm.weight": "model_1.safetensors",
444
+ "model.layers.14.post_attention_layernorm.weight": "model_1.safetensors",
445
+ "model.layers.14.self_attn.q_proj.weight": "model_1.safetensors",
446
+ "model.layers.14.self_attn.k_proj.weight": "model_1.safetensors",
447
+ "model.layers.14.self_attn.v_proj.weight": "model_1.safetensors",
448
+ "model.layers.14.self_attn.o_proj.weight": "model_1.safetensors",
449
+ "model.layers.14.mlp.gate_proj.weight": "model_1.safetensors",
450
+ "model.layers.14.mlp.up_proj.weight": "model_1.safetensors",
451
+ "model.layers.14.mlp.down_proj.weight": "model_1.safetensors",
452
+ "model.layers.15.input_layernorm.weight": "model_1.safetensors",
453
+ "model.layers.15.post_attention_layernorm.weight": "model_1.safetensors",
454
+ "model.layers.15.self_attn.q_proj.weight": "model_1.safetensors",
455
+ "model.layers.15.self_attn.k_proj.weight": "model_1.safetensors",
456
+ "model.layers.15.self_attn.v_proj.weight": "model_1.safetensors",
457
+ "model.layers.15.self_attn.o_proj.weight": "model_1.safetensors",
458
+ "model.layers.15.mlp.gate_proj.weight": "model_1.safetensors",
459
+ "model.layers.15.mlp.up_proj.weight": "model_1.safetensors",
460
+ "model.layers.15.mlp.down_proj.weight": "model_1.safetensors",
461
+ "model.layers.16.mixer.A_log": "model_1.safetensors",
462
+ "model.layers.16.mixer.D": "model_1.safetensors",
463
+ "model.layers.16.input_layernorm.weight": "model_1.safetensors",
464
+ "model.layers.16.post_attention_layernorm.weight": "model_1.safetensors",
465
+ "model.layers.16.mixer.dt_in_proj.weight": "model_1.safetensors",
466
+ "model.layers.16.mixer.conv1d.weight": "model_1.safetensors",
467
+ "model.layers.16.mixer.conv1d.bias": "model_1.safetensors",
468
+ "model.layers.16.mixer.dt_proj.bias": "model_1.safetensors",
469
+ "model.layers.16.mixer.in_proj.weight": "model_1.safetensors",
470
+ "model.layers.16.mixer.dt_proj.weight": "model_1.safetensors",
471
+ "model.layers.16.mixer.out_proj.weight": "model_1.safetensors",
472
+ "model.layers.16.mlp.gate_proj.weight": "model_1.safetensors",
473
+ "model.layers.16.mlp.up_proj.weight": "model_1.safetensors",
474
+ "model.layers.16.mlp.down_proj.weight": "model_1.safetensors",
475
+ "model.layers.17.input_layernorm.weight": "model_1.safetensors",
476
+ "model.layers.17.post_attention_layernorm.weight": "model_1.safetensors",
477
+ "model.layers.17.self_attn.q_proj.weight": "model_1.safetensors",
478
+ "model.layers.17.self_attn.k_proj.weight": "model_1.safetensors",
479
+ "model.layers.17.self_attn.v_proj.weight": "model_1.safetensors",
480
+ "model.layers.17.self_attn.o_proj.weight": "model_1.safetensors",
481
+ "model.layers.17.mlp.gate_proj.weight": "model_1.safetensors",
482
+ "model.layers.17.mlp.up_proj.weight": "model_1.safetensors",
483
+ "model.layers.17.mlp.down_proj.weight": "model_1.safetensors",
484
+ "model.layers.18.mixer.A_log": "model_1.safetensors",
485
+ "model.layers.18.mixer.D": "model_1.safetensors",
486
+ "model.layers.18.input_layernorm.weight": "model_1.safetensors",
487
+ "model.layers.18.post_attention_layernorm.weight": "model_1.safetensors",
488
+ "model.layers.18.mixer.dt_in_proj.weight": "model_1.safetensors",
489
+ "model.layers.18.mixer.conv1d.weight": "model_1.safetensors",
490
+ "model.layers.18.mixer.conv1d.bias": "model_1.safetensors",
491
+ "model.layers.18.mixer.dt_proj.bias": "model_1.safetensors",
492
+ "model.layers.18.mixer.in_proj.weight": "model_1.safetensors",
493
+ "model.layers.18.mixer.dt_proj.weight": "model_1.safetensors",
494
+ "model.layers.18.mixer.out_proj.weight": "model_1.safetensors",
495
+ "model.layers.18.mlp.gate_proj.weight": "model_1.safetensors",
496
+ "model.layers.18.mlp.up_proj.weight": "model_1.safetensors",
497
+ "model.layers.18.mlp.down_proj.weight": "model_1.safetensors",
498
+ "model.layers.19.mixer.A_log": "model_1.safetensors",
499
+ "model.layers.19.mixer.D": "model_1.safetensors",
500
+ "model.layers.19.input_layernorm.weight": "model_1.safetensors",
501
+ "model.layers.19.post_attention_layernorm.weight": "model_1.safetensors",
502
+ "model.layers.19.mixer.dt_in_proj.weight": "model_1.safetensors",
503
+ "model.layers.19.mixer.conv1d.weight": "model_1.safetensors",
504
+ "model.layers.19.mixer.conv1d.bias": "model_1.safetensors",
505
+ "model.layers.19.mixer.dt_proj.bias": "model_1.safetensors",
506
+ "model.layers.19.mixer.in_proj.weight": "model_1.safetensors",
507
+ "model.layers.19.mixer.dt_proj.weight": "model_1.safetensors",
508
+ "model.layers.19.mixer.out_proj.weight": "model_1.safetensors",
509
+ "model.layers.19.mlp.gate_proj.weight": "model_1.safetensors",
510
+ "model.layers.19.mlp.up_proj.weight": "model_1.safetensors",
511
+ "model.layers.19.mlp.down_proj.weight": "model_1.safetensors",
512
+ "model.layers.20.mixer.A_log": "model_1.safetensors",
513
+ "model.layers.20.mixer.D": "model_1.safetensors",
514
+ "model.layers.20.input_layernorm.weight": "model_1.safetensors",
515
+ "model.layers.20.post_attention_layernorm.weight": "model_1.safetensors",
516
+ "model.layers.20.mixer.dt_in_proj.weight": "model_1.safetensors",
517
+ "model.layers.20.mixer.conv1d.weight": "model_1.safetensors",
518
+ "model.layers.20.mixer.conv1d.bias": "model_1.safetensors",
519
+ "model.layers.20.mixer.dt_proj.bias": "model_1.safetensors",
520
+ "model.layers.20.mixer.in_proj.weight": "model_1.safetensors",
521
+ "model.layers.20.mixer.dt_proj.weight": "model_1.safetensors",
522
+ "model.layers.20.mixer.out_proj.weight": "model_1.safetensors",
523
+ "model.layers.20.mlp.gate_proj.weight": "model_1.safetensors",
524
+ "model.layers.20.mlp.up_proj.weight": "model_1.safetensors",
525
+ "model.layers.20.mlp.down_proj.weight": "model_1.safetensors",
526
+ "model.layers.21.mixer.A_log": "model_1.safetensors",
527
+ "model.layers.21.mixer.D": "model_1.safetensors",
528
+ "model.layers.21.input_layernorm.weight": "model_1.safetensors",
529
+ "model.layers.21.post_attention_layernorm.weight": "model_1.safetensors",
530
+ "model.layers.21.mixer.dt_in_proj.weight": "model_1.safetensors",
531
+ "model.layers.21.mixer.conv1d.weight": "model_1.safetensors",
532
+ "model.layers.21.mixer.conv1d.bias": "model_1.safetensors",
533
+ "model.layers.21.mixer.dt_proj.bias": "model_1.safetensors",
534
+ "model.layers.21.mixer.in_proj.weight": "model_1.safetensors",
535
+ "model.layers.21.mixer.dt_proj.weight": "model_1.safetensors",
536
+ "model.layers.21.mixer.out_proj.weight": "model_1.safetensors",
537
+ "model.layers.21.mlp.gate_proj.weight": "model_1.safetensors",
538
+ "model.layers.21.mlp.up_proj.weight": "model_1.safetensors",
539
+ "model.layers.21.mlp.down_proj.weight": "model_1.safetensors",
540
+ "model.layers.22.mixer.A_log": "model_1.safetensors",
541
+ "model.layers.22.mixer.D": "model_1.safetensors",
542
+ "model.layers.22.input_layernorm.weight": "model_1.safetensors",
543
+ "model.layers.22.post_attention_layernorm.weight": "model_1.safetensors",
544
+ "model.layers.22.mixer.dt_in_proj.weight": "model_1.safetensors",
545
+ "model.layers.22.mixer.conv1d.weight": "model_1.safetensors",
546
+ "model.layers.22.mixer.conv1d.bias": "model_1.safetensors",
547
+ "model.layers.22.mixer.dt_proj.bias": "model_1.safetensors",
548
+ "model.layers.22.mixer.in_proj.weight": "model_1.safetensors",
549
+ "model.layers.22.mixer.dt_proj.weight": "model_1.safetensors",
550
+ "model.layers.22.mixer.out_proj.weight": "model_1.safetensors",
551
+ "model.layers.22.mlp.gate_proj.weight": "model_1.safetensors",
552
+ "model.layers.22.mlp.up_proj.weight": "model_1.safetensors",
553
+ "model.layers.22.mlp.down_proj.weight": "model_1.safetensors",
554
+ "model.layers.23.mixer.A_log": "model_1.safetensors",
555
+ "model.layers.23.mixer.D": "model_1.safetensors",
556
+ "model.layers.23.input_layernorm.weight": "model_1.safetensors",
557
+ "model.layers.23.post_attention_layernorm.weight": "model_1.safetensors",
558
+ "model.layers.23.mixer.dt_in_proj.weight": "model_1.safetensors",
559
+ "model.layers.23.mixer.conv1d.weight": "model_1.safetensors",
560
+ "model.layers.23.mixer.conv1d.bias": "model_1.safetensors",
561
+ "model.layers.23.mixer.dt_proj.bias": "model_1.safetensors",
562
+ "model.layers.23.mixer.in_proj.weight": "model_1.safetensors",
563
+ "model.layers.23.mixer.dt_proj.weight": "model_1.safetensors",
564
+ "model.layers.23.mixer.out_proj.weight": "model_1.safetensors",
565
+ "model.layers.23.mlp.gate_proj.weight": "model_1.safetensors",
566
+ "model.layers.23.mlp.up_proj.weight": "model_1.safetensors",
567
+ "model.layers.23.mlp.down_proj.weight": "model_1.safetensors",
568
+ "model.layers.24.mixer.A_log": "model_1.safetensors",
569
+ "model.layers.24.mixer.D": "model_1.safetensors",
570
+ "model.layers.24.input_layernorm.weight": "model_1.safetensors",
571
+ "model.layers.24.post_attention_layernorm.weight": "model_1.safetensors",
572
+ "model.layers.24.mixer.dt_in_proj.weight": "model_1.safetensors",
573
+ "model.layers.24.mixer.conv1d.weight": "model_1.safetensors",
574
+ "model.layers.24.mixer.conv1d.bias": "model_1.safetensors",
575
+ "model.layers.24.mixer.dt_proj.bias": "model_1.safetensors",
576
+ "model.layers.24.mixer.in_proj.weight": "model_1.safetensors",
577
+ "model.layers.24.mixer.dt_proj.weight": "model_1.safetensors",
578
+ "model.layers.24.mixer.out_proj.weight": "model_1.safetensors",
579
+ "model.layers.24.mlp.gate_proj.weight": "model_1.safetensors",
580
+ "model.layers.24.mlp.up_proj.weight": "model_1.safetensors",
581
+ "model.layers.24.mlp.down_proj.weight": "model_1.safetensors",
582
+ "model.layers.25.mixer.A_log": "model_1.safetensors",
583
+ "model.layers.25.mixer.D": "model_1.safetensors",
584
+ "model.layers.25.input_layernorm.weight": "model_1.safetensors",
585
+ "model.layers.25.post_attention_layernorm.weight": "model_1.safetensors",
586
+ "model.layers.25.mixer.dt_in_proj.weight": "model_1.safetensors",
587
+ "model.layers.25.mixer.conv1d.weight": "model_1.safetensors",
588
+ "model.layers.25.mixer.conv1d.bias": "model_1.safetensors",
589
+ "model.layers.25.mixer.dt_proj.bias": "model_1.safetensors",
590
+ "model.layers.25.mixer.in_proj.weight": "model_1.safetensors",
591
+ "model.layers.25.mixer.dt_proj.weight": "model_1.safetensors",
592
+ "model.layers.25.mixer.out_proj.weight": "model_1.safetensors",
593
+ "model.layers.25.mlp.gate_proj.weight": "model_1.safetensors",
594
+ "model.layers.25.mlp.up_proj.weight": "model_1.safetensors",
595
+ "model.layers.25.mlp.down_proj.weight": "model_1.safetensors",
596
+ "model.layers.26.mixer.A_log": "model_1.safetensors",
597
+ "model.layers.26.mixer.D": "model_1.safetensors",
598
+ "model.layers.26.input_layernorm.weight": "model_1.safetensors",
599
+ "model.layers.26.post_attention_layernorm.weight": "model_1.safetensors",
600
+ "model.layers.26.mixer.dt_in_proj.weight": "model_1.safetensors",
601
+ "model.layers.26.mixer.conv1d.weight": "model_1.safetensors",
602
+ "model.layers.26.mixer.conv1d.bias": "model_1.safetensors",
603
+ "model.layers.26.mixer.dt_proj.bias": "model_1.safetensors",
604
+ "model.layers.26.mixer.in_proj.weight": "model_1.safetensors",
605
+ "model.layers.26.mixer.dt_proj.weight": "model_1.safetensors",
606
+ "model.layers.26.mixer.out_proj.weight": "model_1.safetensors",
607
+ "model.layers.26.mlp.gate_proj.weight": "model_1.safetensors",
608
+ "model.layers.26.mlp.up_proj.weight": "model_1.safetensors",
609
+ "model.layers.26.mlp.down_proj.weight": "model_1.safetensors",
610
+ "model.layers.27.mixer.A_log": "model_1.safetensors",
611
+ "model.layers.27.mixer.D": "model_1.safetensors",
612
+ "model.layers.27.input_layernorm.weight": "model_1.safetensors",
613
+ "model.layers.27.post_attention_layernorm.weight": "model_1.safetensors",
614
+ "model.layers.27.mixer.dt_in_proj.weight": "model_1.safetensors",
615
+ "model.layers.27.mixer.conv1d.weight": "model_1.safetensors",
616
+ "model.layers.27.mixer.conv1d.bias": "model_1.safetensors",
617
+ "model.layers.27.mixer.dt_proj.bias": "model_1.safetensors",
618
+ "model.layers.27.mixer.in_proj.weight": "model_1.safetensors",
619
+ "model.layers.27.mixer.dt_proj.weight": "model_1.safetensors",
620
+ "model.layers.27.mixer.out_proj.weight": "model_1.safetensors",
621
+ "model.layers.27.mlp.gate_proj.weight": "model_1.safetensors",
622
+ "model.layers.27.mlp.up_proj.weight": "model_1.safetensors",
623
+ "model.layers.27.mlp.down_proj.weight": "model_1.safetensors",
624
+ "model.layers.28.input_layernorm.weight": "model_2.safetensors",
625
+ "model.layers.28.post_attention_layernorm.weight": "model_2.safetensors",
626
+ "model.layers.28.self_attn.q_proj.weight": "model_2.safetensors",
627
+ "model.layers.28.self_attn.k_proj.weight": "model_2.safetensors",
628
+ "model.layers.28.self_attn.v_proj.weight": "model_2.safetensors",
629
+ "model.layers.28.self_attn.o_proj.weight": "model_2.safetensors",
630
+ "model.layers.28.mlp.gate_proj.weight": "model_2.safetensors",
631
+ "model.layers.28.mlp.up_proj.weight": "model_2.safetensors",
632
+ "model.layers.28.mlp.down_proj.weight": "model_2.safetensors",
633
+ "model.layers.29.input_layernorm.weight": "model_2.safetensors",
634
+ "model.layers.29.post_attention_layernorm.weight": "model_2.safetensors",
635
+ "model.layers.29.self_attn.q_proj.weight": "model_2.safetensors",
636
+ "model.layers.29.self_attn.k_proj.weight": "model_2.safetensors",
637
+ "model.layers.29.self_attn.v_proj.weight": "model_2.safetensors",
638
+ "model.layers.29.self_attn.o_proj.weight": "model_2.safetensors",
639
+ "model.layers.29.mlp.gate_proj.weight": "model_2.safetensors",
640
+ "model.layers.29.mlp.up_proj.weight": "model_2.safetensors",
641
+ "model.layers.29.mlp.down_proj.weight": "model_2.safetensors",
642
+ "model.layers.30.mixer.A_log": "model_2.safetensors",
643
+ "model.layers.30.mixer.D": "model_2.safetensors",
644
+ "model.layers.30.input_layernorm.weight": "model_2.safetensors",
645
+ "model.layers.30.post_attention_layernorm.weight": "model_2.safetensors",
646
+ "model.layers.30.mixer.dt_in_proj.weight": "model_2.safetensors",
647
+ "model.layers.30.mixer.conv1d.weight": "model_2.safetensors",
648
+ "model.layers.30.mixer.conv1d.bias": "model_2.safetensors",
649
+ "model.layers.30.mixer.dt_proj.bias": "model_2.safetensors",
650
+ "model.layers.30.mixer.in_proj.weight": "model_2.safetensors",
651
+ "model.layers.30.mixer.dt_proj.weight": "model_2.safetensors",
652
+ "model.layers.30.mixer.out_proj.weight": "model_2.safetensors",
653
+ "model.layers.30.mlp.gate_proj.weight": "model_2.safetensors",
654
+ "model.layers.30.mlp.up_proj.weight": "model_2.safetensors",
655
+ "model.layers.30.mlp.down_proj.weight": "model_2.safetensors",
656
+ "model.layers.31.mixer.A_log": "model_2.safetensors",
657
+ "model.layers.31.mixer.D": "model_2.safetensors",
658
+ "model.layers.31.input_layernorm.weight": "model_2.safetensors",
659
+ "model.layers.31.post_attention_layernorm.weight": "model_2.safetensors",
660
+ "model.layers.31.mixer.dt_in_proj.weight": "model_2.safetensors",
661
+ "model.layers.31.mixer.conv1d.weight": "model_2.safetensors",
662
+ "model.layers.31.mixer.conv1d.bias": "model_2.safetensors",
663
+ "model.layers.31.mixer.dt_proj.bias": "model_2.safetensors",
664
+ "model.layers.31.mixer.in_proj.weight": "model_2.safetensors",
665
+ "model.layers.31.mixer.dt_proj.weight": "model_2.safetensors",
666
+ "model.layers.31.mixer.out_proj.weight": "model_2.safetensors",
667
+ "model.layers.31.mlp.gate_proj.weight": "model_2.safetensors",
668
+ "model.layers.31.mlp.up_proj.weight": "model_2.safetensors",
669
+ "model.layers.31.mlp.down_proj.weight": "model_2.safetensors",
670
+ "model.layers.32.mixer.A_log": "model_2.safetensors",
671
+ "model.layers.32.mixer.D": "model_2.safetensors",
672
+ "model.layers.32.input_layernorm.weight": "model_2.safetensors",
673
+ "model.layers.32.post_attention_layernorm.weight": "model_2.safetensors",
674
+ "model.layers.32.mixer.dt_in_proj.weight": "model_2.safetensors",
675
+ "model.layers.32.mixer.conv1d.weight": "model_2.safetensors",
676
+ "model.layers.32.mixer.conv1d.bias": "model_2.safetensors",
677
+ "model.layers.32.mixer.dt_proj.bias": "model_2.safetensors",
678
+ "model.layers.32.mixer.in_proj.weight": "model_2.safetensors",
679
+ "model.layers.32.mixer.dt_proj.weight": "model_2.safetensors",
680
+ "model.layers.32.mixer.out_proj.weight": "model_2.safetensors",
681
+ "model.layers.32.mlp.gate_proj.weight": "model_2.safetensors",
682
+ "model.layers.32.mlp.up_proj.weight": "model_2.safetensors",
683
+ "model.layers.32.mlp.down_proj.weight": "model_2.safetensors",
684
+ "model.layers.33.mixer.A_log": "model_2.safetensors",
685
+ "model.layers.33.mixer.D": "model_2.safetensors",
686
+ "model.layers.33.input_layernorm.weight": "model_2.safetensors",
687
+ "model.layers.33.post_attention_layernorm.weight": "model_2.safetensors",
688
+ "model.layers.33.mixer.dt_in_proj.weight": "model_2.safetensors",
689
+ "model.layers.33.mixer.conv1d.weight": "model_2.safetensors",
690
+ "model.layers.33.mixer.conv1d.bias": "model_2.safetensors",
691
+ "model.layers.33.mixer.dt_proj.bias": "model_2.safetensors",
692
+ "model.layers.33.mixer.in_proj.weight": "model_2.safetensors",
693
+ "model.layers.33.mixer.dt_proj.weight": "model_2.safetensors",
694
+ "model.layers.33.mixer.out_proj.weight": "model_2.safetensors",
695
+ "model.layers.33.mlp.gate_proj.weight": "model_2.safetensors",
696
+ "model.layers.33.mlp.up_proj.weight": "model_2.safetensors",
697
+ "model.layers.33.mlp.down_proj.weight": "model_2.safetensors",
698
+ "model.layers.34.input_layernorm.weight": "model_2.safetensors",
699
+ "model.layers.34.post_attention_layernorm.weight": "model_2.safetensors",
700
+ "model.layers.34.self_attn.q_proj.weight": "model_2.safetensors",
701
+ "model.layers.34.self_attn.k_proj.weight": "model_2.safetensors",
702
+ "model.layers.34.self_attn.v_proj.weight": "model_2.safetensors",
703
+ "model.layers.34.self_attn.o_proj.weight": "model_2.safetensors",
704
+ "model.layers.34.mlp.gate_proj.weight": "model_2.safetensors",
705
+ "model.layers.34.mlp.up_proj.weight": "model_2.safetensors",
706
+ "model.layers.34.mlp.down_proj.weight": "model_2.safetensors",
707
+ "model.layers.35.mixer.A_log": "model_2.safetensors",
708
+ "model.layers.35.mixer.D": "model_2.safetensors",
709
+ "model.layers.35.input_layernorm.weight": "model_2.safetensors",
710
+ "model.layers.35.post_attention_layernorm.weight": "model_2.safetensors",
711
+ "model.layers.35.mixer.dt_in_proj.weight": "model_2.safetensors",
712
+ "model.layers.35.mixer.conv1d.weight": "model_2.safetensors",
713
+ "model.layers.35.mixer.conv1d.bias": "model_2.safetensors",
714
+ "model.layers.35.mixer.dt_proj.bias": "model_2.safetensors",
715
+ "model.layers.35.mixer.in_proj.weight": "model_2.safetensors",
716
+ "model.layers.35.mixer.dt_proj.weight": "model_2.safetensors",
717
+ "model.layers.35.mixer.out_proj.weight": "model_2.safetensors",
718
+ "model.layers.35.mlp.gate_proj.weight": "model_2.safetensors",
719
+ "model.layers.35.mlp.up_proj.weight": "model_2.safetensors",
720
+ "model.layers.35.mlp.down_proj.weight": "model_2.safetensors",
721
+ "model.layers.36.mixer.A_log": "model_2.safetensors",
722
+ "model.layers.36.mixer.D": "model_2.safetensors",
723
+ "model.layers.36.input_layernorm.weight": "model_2.safetensors",
724
+ "model.layers.36.post_attention_layernorm.weight": "model_2.safetensors",
725
+ "model.layers.36.mixer.dt_in_proj.weight": "model_2.safetensors",
726
+ "model.layers.36.mixer.conv1d.weight": "model_2.safetensors",
727
+ "model.layers.36.mixer.conv1d.bias": "model_2.safetensors",
728
+ "model.layers.36.mixer.dt_proj.bias": "model_2.safetensors",
729
+ "model.layers.36.mixer.in_proj.weight": "model_2.safetensors",
730
+ "model.layers.36.mixer.dt_proj.weight": "model_2.safetensors",
731
+ "model.layers.36.mixer.out_proj.weight": "model_2.safetensors",
732
+ "model.layers.36.mlp.gate_proj.weight": "model_2.safetensors",
733
+ "model.layers.36.mlp.up_proj.weight": "model_2.safetensors",
734
+ "model.layers.36.mlp.down_proj.weight": "model_2.safetensors",
735
+ "model.layers.37.mixer.A_log": "model_2.safetensors",
736
+ "model.layers.37.mixer.D": "model_2.safetensors",
737
+ "model.layers.37.input_layernorm.weight": "model_2.safetensors",
738
+ "model.layers.37.post_attention_layernorm.weight": "model_2.safetensors",
739
+ "model.layers.37.mixer.dt_in_proj.weight": "model_2.safetensors",
740
+ "model.layers.37.mixer.conv1d.weight": "model_2.safetensors",
741
+ "model.layers.37.mixer.conv1d.bias": "model_2.safetensors",
742
+ "model.layers.37.mixer.dt_proj.bias": "model_2.safetensors",
743
+ "model.layers.37.mixer.in_proj.weight": "model_2.safetensors",
744
+ "model.layers.37.mixer.dt_proj.weight": "model_2.safetensors",
745
+ "model.layers.37.mixer.out_proj.weight": "model_2.safetensors",
746
+ "model.layers.37.mlp.gate_proj.weight": "model_2.safetensors",
747
+ "model.layers.37.mlp.up_proj.weight": "model_2.safetensors",
748
+ "model.layers.37.mlp.down_proj.weight": "model_2.safetensors",
749
+ "model.layers.38.mixer.A_log": "model_2.safetensors",
750
+ "model.layers.38.mixer.D": "model_2.safetensors",
751
+ "model.layers.38.input_layernorm.weight": "model_2.safetensors",
752
+ "model.layers.38.post_attention_layernorm.weight": "model_2.safetensors",
753
+ "model.layers.38.mixer.dt_in_proj.weight": "model_2.safetensors",
754
+ "model.layers.38.mixer.conv1d.weight": "model_2.safetensors",
755
+ "model.layers.38.mixer.conv1d.bias": "model_2.safetensors",
756
+ "model.layers.38.mixer.dt_proj.bias": "model_2.safetensors",
757
+ "model.layers.38.mixer.in_proj.weight": "model_2.safetensors",
758
+ "model.layers.38.mixer.dt_proj.weight": "model_2.safetensors",
759
+ "model.layers.38.mixer.out_proj.weight": "model_2.safetensors",
760
+ "model.layers.38.mlp.gate_proj.weight": "model_2.safetensors",
761
+ "model.layers.38.mlp.up_proj.weight": "model_2.safetensors",
762
+ "model.layers.38.mlp.down_proj.weight": "model_2.safetensors",
763
+ "model.layers.39.mixer.A_log": "model_2.safetensors",
764
+ "model.layers.39.mixer.D": "model_2.safetensors",
+ "model.layers.39.input_layernorm.weight": "model_2.safetensors",
+ "model.layers.39.post_attention_layernorm.weight": "model_2.safetensors",
+ "model.layers.39.mixer.dt_in_proj.weight": "model_2.safetensors",
+ "model.layers.39.mixer.conv1d.weight": "model_2.safetensors",
+ "model.layers.39.mixer.conv1d.bias": "model_2.safetensors",
+ "model.layers.39.mixer.dt_proj.bias": "model_2.safetensors",
+ "model.layers.39.mixer.in_proj.weight": "model_2.safetensors",
+ "model.layers.39.mixer.dt_proj.weight": "model_2.safetensors",
+ "model.layers.39.mixer.out_proj.weight": "model_2.safetensors",
+ "model.layers.39.mlp.gate_proj.weight": "model_2.safetensors",
+ "model.layers.39.mlp.up_proj.weight": "model_2.safetensors",
+ "model.layers.39.mlp.down_proj.weight": "model_2.safetensors",
+ "model.layers.40.mixer.A_log": "model_2.safetensors",
+ "model.layers.40.mixer.D": "model_2.safetensors",
+ "model.layers.40.input_layernorm.weight": "model_2.safetensors",
+ "model.layers.40.post_attention_layernorm.weight": "model_2.safetensors",
+ "model.layers.40.mixer.dt_in_proj.weight": "model_2.safetensors",
+ "model.layers.40.mixer.conv1d.weight": "model_2.safetensors",
+ "model.layers.40.mixer.conv1d.bias": "model_2.safetensors",
+ "model.layers.40.mixer.dt_proj.bias": "model_2.safetensors",
+ "model.layers.40.mixer.in_proj.weight": "model_2.safetensors",
+ "model.layers.40.mixer.dt_proj.weight": "model_2.safetensors",
+ "model.layers.40.mixer.out_proj.weight": "model_2.safetensors",
+ "model.layers.40.mlp.gate_proj.weight": "model_2.safetensors",
+ "model.layers.40.mlp.up_proj.weight": "model_2.safetensors",
+ "model.layers.40.mlp.down_proj.weight": "model_2.safetensors",
+ "model.layers.41.mixer.A_log": "model_2.safetensors",
+ "model.layers.41.mixer.D": "model_2.safetensors",
+ "model.layers.41.input_layernorm.weight": "model_2.safetensors",
+ "model.layers.41.post_attention_layernorm.weight": "model_2.safetensors",
+ "model.layers.41.mixer.dt_in_proj.weight": "model_2.safetensors",
+ "model.layers.41.mixer.conv1d.weight": "model_2.safetensors",
+ "model.layers.41.mixer.conv1d.bias": "model_2.safetensors",
+ "model.layers.41.mixer.dt_proj.bias": "model_2.safetensors",
+ "model.layers.41.mixer.in_proj.weight": "model_2.safetensors",
+ "model.layers.41.mixer.dt_proj.weight": "model_2.safetensors",
+ "model.layers.41.mixer.out_proj.weight": "model_2.safetensors",
+ "model.layers.41.mlp.gate_proj.weight": "model_2.safetensors",
+ "model.layers.41.mlp.up_proj.weight": "model_2.safetensors",
+ "model.layers.41.mlp.down_proj.weight": "model_2.safetensors",
+ "model.layers.42.mixer.A_log": "model_2.safetensors",
+ "model.layers.42.mixer.D": "model_2.safetensors",
+ "model.layers.42.input_layernorm.weight": "model_2.safetensors",
+ "model.layers.42.post_attention_layernorm.weight": "model_2.safetensors",
+ "model.layers.42.mixer.dt_in_proj.weight": "model_2.safetensors",
+ "model.layers.42.mixer.conv1d.weight": "model_2.safetensors",
+ "model.layers.42.mixer.conv1d.bias": "model_2.safetensors",
+ "model.layers.42.mixer.dt_proj.bias": "model_2.safetensors",
+ "model.layers.42.mixer.in_proj.weight": "model_2.safetensors",
+ "model.layers.42.mixer.dt_proj.weight": "model_2.safetensors",
+ "model.layers.42.mixer.out_proj.weight": "model_2.safetensors",
+ "model.layers.42.mlp.gate_proj.weight": "model_2.safetensors",
+ "model.layers.42.mlp.up_proj.weight": "model_2.safetensors",
+ "model.layers.42.mlp.down_proj.weight": "model_3.safetensors",
+ "model.layers.43.mixer.A_log": "model_3.safetensors",
+ "model.layers.43.mixer.D": "model_3.safetensors",
+ "model.layers.43.input_layernorm.weight": "model_3.safetensors",
+ "model.layers.43.post_attention_layernorm.weight": "model_3.safetensors",
+ "model.layers.43.mixer.dt_in_proj.weight": "model_3.safetensors",
+ "model.layers.43.mixer.conv1d.weight": "model_3.safetensors",
+ "model.layers.43.mixer.conv1d.bias": "model_3.safetensors",
+ "model.layers.43.mixer.dt_proj.bias": "model_3.safetensors",
+ "model.layers.43.mixer.in_proj.weight": "model_3.safetensors",
+ "model.layers.43.mixer.dt_proj.weight": "model_3.safetensors",
+ "model.layers.43.mixer.out_proj.weight": "model_3.safetensors",
+ "model.layers.43.mlp.gate_proj.weight": "model_3.safetensors",
+ "model.layers.43.mlp.up_proj.weight": "model_3.safetensors",
+ "model.layers.43.mlp.down_proj.weight": "model_3.safetensors",
+ "model.layers.44.mixer.A_log": "model_3.safetensors",
+ "model.layers.44.mixer.D": "model_3.safetensors",
+ "model.layers.44.input_layernorm.weight": "model_3.safetensors",
+ "model.layers.44.post_attention_layernorm.weight": "model_3.safetensors",
+ "model.layers.44.mixer.dt_in_proj.weight": "model_3.safetensors",
+ "model.layers.44.mixer.conv1d.weight": "model_3.safetensors",
+ "model.layers.44.mixer.conv1d.bias": "model_3.safetensors",
+ "model.layers.44.mixer.dt_proj.bias": "model_3.safetensors",
+ "model.layers.44.mixer.in_proj.weight": "model_3.safetensors",
+ "model.layers.44.mixer.dt_proj.weight": "model_3.safetensors",
+ "model.layers.44.mixer.out_proj.weight": "model_3.safetensors",
+ "model.layers.44.mlp.gate_proj.weight": "model_3.safetensors",
+ "model.layers.44.mlp.up_proj.weight": "model_3.safetensors",
+ "model.layers.44.mlp.down_proj.weight": "model_3.safetensors",
+ "model.layers.45.mixer.A_log": "model_3.safetensors",
+ "model.layers.45.mixer.D": "model_3.safetensors",
+ "model.layers.45.input_layernorm.weight": "model_3.safetensors",
+ "model.layers.45.post_attention_layernorm.weight": "model_3.safetensors",
+ "model.layers.45.mixer.dt_in_proj.weight": "model_3.safetensors",
+ "model.layers.45.mixer.conv1d.weight": "model_3.safetensors",
+ "model.layers.45.mixer.conv1d.bias": "model_3.safetensors",
+ "model.layers.45.mixer.dt_proj.bias": "model_3.safetensors",
+ "model.layers.45.mixer.in_proj.weight": "model_3.safetensors",
+ "model.layers.45.mixer.dt_proj.weight": "model_3.safetensors",
+ "model.layers.45.mixer.out_proj.weight": "model_3.safetensors",
+ "model.layers.45.mlp.gate_proj.weight": "model_3.safetensors",
+ "model.layers.45.mlp.up_proj.weight": "model_3.safetensors",
+ "model.layers.45.mlp.down_proj.weight": "model_3.safetensors",
+ "model.layers.46.mixer.A_log": "model_3.safetensors",
+ "model.layers.46.mixer.D": "model_3.safetensors",
+ "model.layers.46.input_layernorm.weight": "model_3.safetensors",
+ "model.layers.46.post_attention_layernorm.weight": "model_3.safetensors",
+ "model.layers.46.mixer.dt_in_proj.weight": "model_3.safetensors",
+ "model.layers.46.mixer.conv1d.weight": "model_3.safetensors",
+ "model.layers.46.mixer.conv1d.bias": "model_3.safetensors",
+ "model.layers.46.mixer.dt_proj.bias": "model_3.safetensors",
+ "model.layers.46.mixer.in_proj.weight": "model_3.safetensors",
+ "model.layers.46.mixer.dt_proj.weight": "model_3.safetensors",
+ "model.layers.46.mixer.out_proj.weight": "model_3.safetensors",
+ "model.layers.46.mlp.gate_proj.weight": "model_3.safetensors",
+ "model.layers.46.mlp.up_proj.weight": "model_3.safetensors",
+ "model.layers.46.mlp.down_proj.weight": "model_3.safetensors",
+ "model.layers.47.mixer.A_log": "model_3.safetensors",
+ "model.layers.47.mixer.D": "model_3.safetensors",
+ "model.layers.47.input_layernorm.weight": "model_3.safetensors",
+ "model.layers.47.post_attention_layernorm.weight": "model_3.safetensors",
+ "model.layers.47.mixer.dt_in_proj.weight": "model_3.safetensors",
+ "model.layers.47.mixer.conv1d.weight": "model_3.safetensors",
+ "model.layers.47.mixer.conv1d.bias": "model_3.safetensors",
+ "model.layers.47.mixer.dt_proj.bias": "model_3.safetensors",
+ "model.layers.47.mixer.in_proj.weight": "model_3.safetensors",
+ "model.layers.47.mixer.dt_proj.weight": "model_3.safetensors",
+ "model.layers.47.mixer.out_proj.weight": "model_3.safetensors",
+ "model.layers.47.mlp.gate_proj.weight": "model_3.safetensors",
+ "model.layers.47.mlp.up_proj.weight": "model_3.safetensors",
+ "model.layers.47.mlp.down_proj.weight": "model_3.safetensors",
+ "model.layers.48.input_layernorm.weight": "model_3.safetensors",
+ "model.layers.48.post_attention_layernorm.weight": "model_3.safetensors",
+ "model.layers.48.self_attn.q_proj.weight": "model_3.safetensors",
+ "model.layers.48.self_attn.k_proj.weight": "model_3.safetensors",
+ "model.layers.48.self_attn.v_proj.weight": "model_3.safetensors",
+ "model.layers.48.self_attn.o_proj.weight": "model_3.safetensors",
+ "model.layers.48.mlp.gate_proj.weight": "model_3.safetensors",
+ "model.layers.48.mlp.up_proj.weight": "model_3.safetensors",
+ "model.layers.48.mlp.down_proj.weight": "model_3.safetensors",
+ "model.layers.49.mixer.A_log": "model_3.safetensors",
+ "model.layers.49.mixer.D": "model_3.safetensors",
+ "model.layers.49.input_layernorm.weight": "model_3.safetensors",
+ "model.layers.49.post_attention_layernorm.weight": "model_3.safetensors",
+ "model.layers.49.mixer.dt_in_proj.weight": "model_3.safetensors",
+ "model.layers.49.mixer.conv1d.weight": "model_3.safetensors",
+ "model.layers.49.mixer.conv1d.bias": "model_3.safetensors",
+ "model.layers.49.mixer.dt_proj.bias": "model_3.safetensors",
+ "model.layers.49.mixer.in_proj.weight": "model_3.safetensors",
+ "model.layers.49.mixer.dt_proj.weight": "model_3.safetensors",
+ "model.layers.49.mixer.out_proj.weight": "model_3.safetensors",
+ "model.layers.49.mlp.gate_proj.weight": "model_3.safetensors",
+ "model.layers.49.mlp.up_proj.weight": "model_3.safetensors",
+ "model.layers.49.mlp.down_proj.weight": "model_3.safetensors",
+ "model.norm.weight": "model_3.safetensors",
+ "lm_head.weight": "model_3.safetensors"
+ }
+ }
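The weight map above assigns every tensor name to one of the shard files, and a layer can straddle two shards (layer 42's `mlp.down_proj.weight` sits in `model_3.safetensors` while the rest of that layer is in `model_2.safetensors`). The following is a minimal sketch of how such a map can be inspected, assuming the index has already been parsed into a dict with a `weight_map` key. `sample_index` is a hand-copied excerpt, and `tensors_per_shard` and `layers_split_across_shards` are illustrative helper names, not part of this repository.

```python
from collections import defaultdict

# Hand-copied excerpt of the weight map shown above (not the full index).
sample_index = {
    "weight_map": {
        "model.layers.42.mlp.up_proj.weight": "model_2.safetensors",
        "model.layers.42.mlp.down_proj.weight": "model_3.safetensors",
        "model.norm.weight": "model_3.safetensors",
        "lm_head.weight": "model_3.safetensors",
    }
}


def tensors_per_shard(index: dict) -> dict[str, list[str]]:
    """Group tensor names by the shard file that stores them."""
    shards = defaultdict(list)
    for name, shard in index["weight_map"].items():
        shards[shard].append(name)
    return dict(shards)


def layers_split_across_shards(index: dict) -> list[str]:
    """Return the layer prefixes whose tensors live in more than one shard."""
    shards_by_layer = defaultdict(set)
    for name, shard in index["weight_map"].items():
        parts = name.split(".")
        if parts[:2] == ["model", "layers"]:
            shards_by_layer[".".join(parts[:3])].add(shard)
    return sorted(layer for layer, shards in shards_by_layer.items() if len(shards) > 1)


grouped = tensors_per_shard(sample_index)
split_layers = layers_split_across_shards(sample_index)
print(grouped)
print(split_layers)
```

A loader that streams one shard at a time would need both files open to assemble a layer reported by `layers_split_across_shards`.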
model_0.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:54ab7d8a72b9ed8dddb280c6b894a0a7c6404a6545cf77821e075708353ec122
+ size 17341952504
model_1.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3cf3c4655d0a8f554e50c3c8001f92f681edb1a68a18eae6f4c55857756ed141
+ size 17415079472
model_2.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:03913e3b49ba8415e659d23f28e30f55a973ff5b669ef8e323226331e7f98e84
+ size 17217537984
model_3.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e5288d374646a02d5bec8810577e6955c90e810fd958bcd1c835d17a1d45e755
+ size 11188268104
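Each `model_N.safetensors` entry above is a Git LFS pointer: three `key value` lines giving the spec version, the SHA-256 object id, and the size in bytes. As a rough sketch (the helper name `parse_lfs_pointer` is an assumption for illustration, and the sizes are copied by hand from the four pointers shown), the fields can be parsed and the shard sizes summed like this:

```python
def parse_lfs_pointer(text: str) -> dict[str, str]:
    """Parse the 'key value' lines of a Git LFS pointer file into a dict."""
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    return fields


# Pointer contents of model_3.safetensors, copied from the diff above.
pointer_text = """version https://git-lfs.github.com/spec/v1
oid sha256:e5288d374646a02d5bec8810577e6955c90e810fd958bcd1c835d17a1d45e755
size 11188268104
"""

info = parse_lfs_pointer(pointer_text)

# Sizes (in bytes) of model_0 .. model_3, copied from the four pointers.
shard_sizes = [17341952504, 17415079472, 17217537984, 11188268104]
total_bytes = sum(shard_sizes)

print(info["oid"], int(info["size"]))
print(f"total checkpoint size: {total_bytes / 1e9:.2f} GB")
```

The sum is useful as a quick disk-space estimate before pulling the LFS objects.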
modeling_apriel_h.py ADDED
@@ -0,0 +1,908 @@
+ import copy
+ import math
+ from dataclasses import dataclass
+ from typing import Any, Optional, Union
+
+ import torch
+ import torch.nn.functional as F
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
+ from .configuration_apriel_h import AprielHConfig
+ from einops import rearrange, repeat
+ from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update
+ from torch import nn
+ from transformers import GenerationMixin
+ from transformers.cache_utils import Cache, DynamicCache
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralMLP, MistralModel, MistralRMSNorm
+ from transformers.processing_utils import Unpack
+ from transformers.utils import LossKwargs, can_return_tuple, logging
+ from transformers.utils.generic import ModelOutput
+
+ logger = logging.get_logger(__name__)
+
+
+ is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
+
+
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+     """
+     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+     num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+     """
+     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+     if n_rep == 1:
+         return hidden_states
+     hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
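The `repeat_kv` docstring equates the function with `torch.repeat_interleave` along the head dimension: each key/value head is repeated `n_rep` times in place, rather than the whole block of heads being tiled end to end. The sketch below mirrors only that ordering on plain Python lists (no tensors involved); `repeat_heads` is an illustrative name and not part of this file.

```python
def repeat_heads(heads: list, n_rep: int) -> list:
    """Repeat each head n_rep times consecutively (interleaved, not tiled)."""
    return [head for head in heads for _ in range(n_rep)]


kv_heads = ["kv0", "kv1"]
expanded = repeat_heads(kv_heads, 3)
print(expanded)
```

With this ordering, query head `i` pairs with key/value head `i // n_rep`, which is the grouping that grouped-query attention relies on.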
+ # Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py
+ class HybridMambaAttentionDynamicCache(DynamicCache):
+     """
+     A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
+     (which has a constant shape regardless of seq_len).
+     This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states`
+     and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape of each tensor
+     depends on the layer type.
+     For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`,
+     while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors).
+     For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors),
+     while `conv_states` holds the convolution state, with shape `(batch_size, d_inner, d_conv)` for `m2` layers and
+     `(batch_size, d_inner + 2 * n_qk_heads * d_state, d_conv)` for `m2d` layers,
+     and `ssm_states` holds the ssm state, with shape `(batch_size, num_heads, head_dim, d_state)`.
+     """
+
+     def __init__(self, config: AprielHConfig, batch_size, dtype=torch.float16, device=None):
+         super().__init__()
+         self.dtype = dtype
+         self.hybrid_override_pattern = config.hybrid_block_layout
+         self.has_previous_state = False  # only used by mamba
+         intermediate_size = (
+             config.ssm_cfg["d_inner"]
+             if config.ssm_cfg["d_inner"] is not None
+             else config.ssm_cfg["expand"] * config.hidden_size
+         )
+         ssm_state_size = config.ssm_cfg["d_state"]
+         conv_kernel_size = config.ssm_cfg["d_conv"]
+         self.n_qk_heads = config.ssm_cfg["n_qk_heads"]
+         self.num_C_head = intermediate_size // ssm_state_size  # mamba2
+         assert intermediate_size % self.n_qk_heads == 0, "d_inner must be divisible by n_qk_heads"
+         self.head_d = intermediate_size // self.n_qk_heads
+         self.conv_states = []
+         self.ssm_states = []
+         self.transformer_layers = []
+         for i in range(config.num_hidden_layers):
+             if self.hybrid_override_pattern[i] == "m2d":
+                 # Mamba layer
+                 self.conv_states += [
+                     torch.zeros(
+                         batch_size,
+                         conv_kernel_size,
+                         intermediate_size + 2 * self.n_qk_heads * ssm_state_size,
+                         device=device,
+                         dtype=dtype,
+                     ).transpose(1, 2)
+                 ]
+                 self.ssm_states += [
+                     torch.zeros(batch_size, self.n_qk_heads, self.head_d, ssm_state_size, device=device, dtype=dtype)
+                 ]
+             elif self.hybrid_override_pattern[i] == "m2":
+                 if "repeat_kv_before_conv" in config.ssm_cfg:
+                     assert (
+                         config.ssm_cfg["repeat_kv_before_conv"] == True
+                     ), "Only support repeat_kv_before_conv=True for m2 for now"
+
+                 self.conv_states += [
+                     torch.zeros(
+                         batch_size,
+                         intermediate_size,
+                         conv_kernel_size,
+                         device=device,
+                         dtype=dtype,
+                     )
+                 ]
+                 self.ssm_states += [
+                     torch.zeros(
+                         batch_size,
+                         self.num_C_head,
+                         intermediate_size // self.num_C_head,
+                         ssm_state_size,
+                         device=device,
+                         dtype=dtype,
+                     )
+                 ]
+             else:
+                 # Attention or MLP layer
+                 self.conv_states += [torch.tensor([[]] * batch_size, device=device)]
+                 self.ssm_states += [torch.tensor([[]] * batch_size, device=device)]
+                 self.transformer_layers.append(i)
+
+         self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
+         self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
+
+     def update(
+         self,
+         key_states: torch.Tensor,
+         value_states: torch.Tensor,
+         layer_idx: int,
+         cache_kwargs: Optional[dict[str, Any]] = None,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         # Update the cache
+         if self.key_cache[layer_idx].shape[-1] == 0:
+             self.key_cache[layer_idx] = key_states
+             self.value_cache[layer_idx] = value_states
+         else:
+             self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2)
+             self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2)
+
+         return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+     def reorder_cache(self, beam_idx: torch.LongTensor):
+         """Reorders the cache for beam search, given the selected beam indices."""
+         for layer_idx in range(len(self.key_cache)):
+             device = self.key_cache[layer_idx].device
+             self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
+             device = self.value_cache[layer_idx].device
+             self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
+
+             device = self.conv_states[layer_idx].device
+             self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device))
+             device = self.ssm_states[layer_idx].device
+             self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device))
+
+     def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]:
+         raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")
+
+     @classmethod
+     def from_legacy_cache(cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
+         raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")
+
+     # Copied from modeling_mamba2.py
+     def update_conv_state(
+         self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False
+     ) -> torch.Tensor:
+         # `conv_states` is a list of per-layer tensors, so the device comes from this layer's tensor.
+         device = self.conv_states[layer_idx].device
+         if cache_init:
+             self.conv_states[layer_idx] = new_conv_state.to(device)
+         else:
+             self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1)
+             self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(device)
+         return self.conv_states[layer_idx]
+
+     def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
+         self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device)
+         return self.ssm_states[layer_idx]
+
+     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+         # Mamba layers hold no key/value cache, so fall back to the first attention layer for them.
+         layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx
+         if len(self.key_cache) <= layer_idx:
+             return 0
+         is_empty_layer = (
+             len(self.key_cache) == 0  # no cache in any layer
+             or len(self.key_cache) <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it
+             or not self.key_cache[layer_idx].numel()  # the layer has no cache
+         )
+         return self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
+
+     def reset(self):
+         # `conv_states` and `ssm_states` are lists, so zero each per-layer tensor in place.
+         for state in (*self.conv_states, *self.ssm_states):
+             state.zero_()
+
+
+ @dataclass
+ class AprielHybridCausalOutput(ModelOutput):
+     """Custom output class for MambaLMHeadModel."""
+
+     loss: Optional[torch.FloatTensor] = None
+     logits: Optional[torch.FloatTensor] = None
+     all_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+     last_hidden_state: Optional[torch.FloatTensor] = None
+     attention_weights: Optional[torch.FloatTensor] = None
+     past_key_values: Optional[Cache] = None
+
+
+ def segsum(x):
+     """More stable segment sum calculation."""
+     # [1, 2, 3]
+     T = x.size(-1)
+     x = repeat(x, "... d -> ... d e", e=T)
+     # [[1, 1, 1], [2, 2, 2], [3, 3, 3]]
+     mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1)
+     x = x.masked_fill(~mask, 0)
+     # [[0, 0, 0], [2, 0, 0], [3, 3, 0]]
+     x_segsum = torch.cumsum(x, dim=-2)
+     # [[0, 0, 0], [2, 0, 0], [5, 3, 0]]
+     mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
+     x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
+     return x_segsum
+
+
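The inline comments in `segsum` trace the input `[1, 2, 3]` step by step. As a cross-check, here is a plain-Python sketch of the same quantity for a 1-D input: entry `(i, j)` is the sum of `x[j+1], ..., x[i]` when `i >= j`, and `-inf` above the diagonal. `segsum_reference` is an illustrative name, not part of this file, and the nested loops are written for clarity rather than speed.

```python
import math


def segsum_reference(x: list[float]) -> list[list[float]]:
    """Segment sums out[i][j] = sum(x[j+1 : i+1]) for i >= j, -inf above the diagonal."""
    n = len(x)
    out = [[-math.inf] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            # An empty slice (i == j) sums to 0, which gives the zero diagonal.
            out[i][j] = sum(x[j + 1 : i + 1])
    return out


result = segsum_reference([1, 2, 3])
for row in result:
    print(row)
```

In `materialize_mixer`, `torch.exp` is applied to this matrix, so the `-inf` entries above the diagonal become zeros and the resulting decay factors are causal.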
+ def materialize_mixer(A_log, B, C, D):
+     """
+     Since the transfer matrix will be equated to the attention matrix,
+     we need to support the form: torch.matmul(attn_weights, value_states).
+     Thus, y = torch.matmul(T, X)
+     Arguments:
+         A_log: (batch, length, n_heads)
+         B: (batch, length, n_heads, d_state)
+         C: (batch, length, n_heads, d_state)
+         D: (n_heads,) skip term added to the diagonal of T, or None
+     Return:
+         T: (batch, n_heads, length, length)
+     """
+     batch_size, length, n_heads, d_state = B.shape
+     assert A_log.shape == (batch_size, length, n_heads)
+     assert B.shape == C.shape == (batch_size, length, n_heads, d_state)
+
+     # Compute:
+     A_log = rearrange(-F.softplus(A_log), "b l h -> b h l")
+     powers = torch.exp(segsum(A_log))
+     T = torch.einsum("blhn,bshn,bhls->bhsl", C, B, powers)
+
+     # Add D:
+     if D is not None:
+         T[:, :, torch.arange(length), torch.arange(length)] += D.view(1, n_heads, 1)
+
+     T = rearrange(T, "b h z l -> b h l z")
+     return T
+
+
+ def apply_mask_to_padding_states(hidden_states, attention_mask):
+     """
+     Zeroes out the hidden states of padding tokens, see https://github.com/state-spaces/mamba/issues/66
+     """
+     if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
+         dtype = hidden_states.dtype
+         hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
+
+     return hidden_states
+
+
+ class Mamba(nn.Module):
+     def __init__(
+         self,
+         d_model,
+         d_inner,
+         d_xb=None,
+         d_state=16,
+         d_conv=4,
+         expand=2,
+         dt_rank="auto",
+         dt_min=0.001,
+         dt_max=0.1,
+         dt_init="random",
+         dt_scale=1.0,
+         dt_init_floor=1e-4,
+         repeat_kv_before_conv=True,
+         conv_bias=True,
+         bias=False,
+         dt_proj_bias=True,
+         use_fast_path=True,  # Fused kernel options
+         layer_idx=None,
+         device=None,
+         dtype=None,
+         **kwargs,
+     ):
+         factory_kwargs = {"device": device, "dtype": dtype}
+         super().__init__()
+         self.d_model = d_model
+         self.d_xb = d_xb if d_xb is not None else d_model
+         self.d_state = d_state
+         self.d_conv = d_conv
+         self.expand = expand
+         self.d_inner = d_inner if d_inner is not None else int(self.expand * self.d_model)
+         self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
+         self.use_fast_path = use_fast_path
+         self.layer_idx = layer_idx
+         self.repeat_kv_before_conv = repeat_kv_before_conv
+
+         if self.repeat_kv_before_conv:
+             self.conv1d = nn.Conv1d(
+                 in_channels=self.d_inner,
+                 out_channels=self.d_inner,
+                 bias=conv_bias,
+                 kernel_size=d_conv,
+                 groups=self.d_inner,
+                 padding=d_conv - 1,
+                 **factory_kwargs,
+             )
+         else:
+             self.conv1d = nn.Conv1d(
+                 in_channels=self.d_xb,
+                 out_channels=self.d_xb,
+                 bias=conv_bias,
+                 kernel_size=d_conv,
+                 groups=self.d_xb,
+                 padding=d_conv - 1,
+                 **factory_kwargs,
+             )
+
+         self.activation = "silu"
+         self.act = nn.SiLU()
+
+         self.num_xb_head = self.d_xb // self.d_state
+         self.num_C_head = self.d_inner // self.d_state
+         self.repeat_group = self.num_C_head // self.num_xb_head
+
+         self.in_proj = nn.Linear(self.d_model, 2 * self.d_xb + 2 * self.d_inner, bias=bias, **factory_kwargs)
+         self.dt_in_proj = nn.Linear(self.d_model, self.dt_rank, bias=bias, **factory_kwargs)
+         self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=dt_proj_bias, **factory_kwargs)
+
+         # Initialize special dt projection to preserve variance at initialization
+         dt_init_std = self.dt_rank**-0.5 * dt_scale
+         if dt_init == "constant":
+             nn.init.constant_(self.dt_proj.weight, dt_init_std)
+         elif dt_init == "random":
+             nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
+         else:
+             raise NotImplementedError
+
+         # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
+         dt = torch.exp(
+             torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
+         ).clamp(min=dt_init_floor)
+         # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
+         inv_dt = dt + torch.log(-torch.expm1(-dt))
+         with torch.no_grad():
+             self.dt_proj.bias.copy_(inv_dt)
+         # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
+         self.dt_proj.bias._no_reinit = True
+
+         # S4D real initialization
+         A = repeat(
+             torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
+             "n -> d n",
+             d=self.d_inner,
+         ).contiguous()
+         A_log = torch.log(A)  # Keep A_log in fp32
+         self.A_log = nn.Parameter(A_log)
+         self.A_log._no_weight_decay = True
+
+         # D "skip" parameter
+         self.D = nn.Parameter(torch.ones(self.d_inner, device=device))  # Keep in fp32
+         self.D._no_weight_decay = True
+
+         self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
+
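The `dt_proj.bias` initialization in `__init__` samples `dt` log-uniformly between `dt_min` and `dt_max`, clamps it at `dt_init_floor`, and stores the inverse softplus of that value, so that `softplus(bias)` recovers `dt`. The identity used is `softplus^{-1}(y) = log(exp(y) - 1) = y + log(-expm1(-y))`. Below is a scalar sketch with only the standard library; the helper names are illustrative and the defaults are copied from the constructor signature.

```python
import math
import random


def softplus(x: float) -> float:
    return math.log1p(math.exp(x))


def inverse_softplus(y: float) -> float:
    """Numerically friendly form of log(exp(y) - 1) for y > 0."""
    return y + math.log(-math.expm1(-y))


def sample_dt(dt_min: float = 0.001, dt_max: float = 0.1, dt_init_floor: float = 1e-4) -> float:
    """Draw dt log-uniformly in [dt_min, dt_max], then clamp from below."""
    u = random.random()
    dt = math.exp(u * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min))
    return max(dt, dt_init_floor)


dt = sample_dt()
bias = inverse_softplus(dt)
print(dt, bias, softplus(bias))
```

Because the scan is later called with `delta_softplus=True` and `delta_bias=self.dt_proj.bias`, storing the inverse-softplus value means the step size starts out in the intended `[dt_min, dt_max]` range when the projected input is near zero.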
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
+         mamba_mask: Optional[torch.Tensor] = None,
+         return_mixer_matrix=False,
+         **kwargs,
+     ):
+         """
+         hidden_states: (B, L, D)
+         Returns: a dict whose "hidden_states" entry has the same shape as the input hidden_states
+         """
+         assert is_fast_path_available and "cuda" in self.in_proj.weight.device.type, "Only support fast path on cuda"
+         cache_position = kwargs.get("cache_position", None)
+         batch, seqlen, dim = hidden_states.shape
+
+         ssm_state, conv_state = None, None
+         use_precomputed_states = False
+
+         #########################################################
+         # Quick and dirty to work with CG
+         if "inference_params" in kwargs:
+             seqlen_offset = kwargs["inference_params"].seqlen_offset
+             if seqlen_offset > 0:
+                 use_precomputed_states = True
+         else:
+             seqlen_offset = kwargs.get("seqlen_offset", cache_position[0]) if cache_position is not None else 0
+             use_precomputed_states = (
+                 past_key_value is not None
+                 and past_key_value.has_previous_state
+                 and seqlen == 1
+                 and past_key_value.conv_states[self.layer_idx].shape[0]
+                 == past_key_value.ssm_states[self.layer_idx].shape[0]
+                 == batch
+                 and cache_position is not None
+                 and seqlen_offset > 0
+             )
+         #########################################################
+
+         ssm_state, conv_state = self._get_states_from_cache(past_key_value, batch)
+         if use_precomputed_states:
+             out, _, _ = self.step(hidden_states, conv_state, ssm_state)
+             return {"hidden_states": out}
+
+         outputs = {}
+         A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
+
+         zxbc = self.in_proj(hidden_states)
+         z, x, B, C = torch.split(
+             zxbc,
+             [
+                 self.d_inner,
+                 self.d_xb,
+                 self.d_xb,
+                 self.d_inner,
+             ],
+             dim=-1,
+         )
+
+         x = rearrange(x, "b l d -> b d l")
+         z = rearrange(z, "b l d -> b d l")
+
+         B = rearrange(B, "b l (n_group dstate) -> b n_group l dstate", dstate=self.d_state)
+         B = repeat_kv(B, self.repeat_group)  # B, n_group, L, H
+         B = rearrange(B, "b n_group l dstate -> b n_group dstate l").contiguous()
+         C = rearrange(C, "b l (n_group dstate) -> b n_group dstate l", dstate=self.d_state).contiguous()
+
+         dt = self.dt_proj(self.dt_in_proj(hidden_states))  # B, L, d_inner
+         dt = rearrange(dt, "b l d -> b d l")  # B, d_inner, L
+
+         if self.repeat_kv_before_conv:
+             x = rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state)
+             x = repeat_kv(x, self.repeat_group)
+             x = rearrange(x, "b n_group l dstate -> b (n_group dstate) l")
+
+         # Compute short convolution
+         if conv_state is not None:
+             # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
+             # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
+             # Update state (B D W)
+             conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0)))
+         if causal_conv1d_fn is None:
+             x = self.act(self.conv1d(x)[..., :seqlen]).transpose(1, 2)
+         else:
+             assert self.activation in ["silu", "swish"]
+             x = causal_conv1d_fn(
+                 x=x,
+                 weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
+                 bias=self.conv1d.bias,
+                 activation=self.activation,
+             )
+
+         if not self.repeat_kv_before_conv:
+             x = rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state)
+             x = repeat_kv(x, self.repeat_group)
+             x = rearrange(x, "b n_group l dstate -> b (n_group dstate) l")
+
+         y = selective_scan_fn(
+             x,
+             dt,
+             A,
+             B,
+             C,
+             self.D.float(),
+             z=z,
+             delta_bias=self.dt_proj.bias.float(),
+             delta_softplus=True,
+             return_last_state=(ssm_state is not None),
+         )
+
+         if ssm_state is not None:
+             y, last_state = y
+             ssm_state.copy_(rearrange(last_state, "b (h d) n -> b h d n", h=self.num_C_head))
+
+         y = rearrange(y, "b d l -> b l d")
+         out = self.out_proj(y)
+
+         outputs["hidden_states"] = out[:, :seqlen, :]
+         return outputs
+
496
+ def step(self, hidden_states, conv_state, ssm_state):
497
+ dtype = hidden_states.dtype
498
+ assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
499
+
500
+ hidden_states_input = hidden_states.squeeze(1)
501
+
502
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
503
+
504
+ zxbc = self.in_proj(hidden_states_input)
505
+ z, x, B, C = torch.split(zxbc, [self.d_inner, self.d_xb, self.d_xb, self.d_inner], dim=-1)
506
+
507
+ B = rearrange(B, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state)
508
+ B = torch.repeat_interleave(B, dim=1, repeats=self.repeat_group)
509
+ C = rearrange(C, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state).contiguous()
510
+
511
+ dt = self.dt_proj(self.dt_in_proj(hidden_states_input)) # B, d_inner
512
+
513
+ if self.repeat_kv_before_conv:
514
+ x = rearrange(x, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state)
515
+ x = torch.repeat_interleave(x, dim=1, repeats=self.repeat_group)
516
+ x = rearrange(x, "b n_group dstate -> b (n_group dstate)")
517
+
518
+ # Conv step
519
+ if causal_conv1d_update is None:
520
+ # Update state (B D W)
521
+ conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))
522
+ conv_state[:, :, -1] = x
523
+ x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
524
+ if self.conv1d.bias is not None:
525
+ x = x + self.conv1d.bias
526
+ x = self.act(x).to(dtype=dtype)
527
+ else:
528
+ x = causal_conv1d_update(
529
+ x,
530
+ conv_state,
531
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
532
+ self.conv1d.bias,
533
+ self.activation,
534
+ )
535
+
536
+ if not self.repeat_kv_before_conv:
537
+ x = rearrange(x, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state)
538
+ x = torch.repeat_interleave(x, dim=1, repeats=self.repeat_group)
539
+ x = rearrange(x, "b n_group dstate -> b (n_group dstate)")
540
+
541
+ x = rearrange(x, "b (h d) -> b h d", h=self.num_C_head)
542
+ dt = rearrange(dt, "b (h d) -> b h d", h=self.num_C_head)
543
+ A = rearrange(A, "(h d) n -> h d n", h=self.num_C_head)
544
+ D = rearrange(self.D, "(h d) -> h d", h=self.num_C_head)
545
+ z = rearrange(z, "b (h d) -> b h d", h=self.num_C_head)
546
+ dt_bias = rearrange(self.dt_proj.bias, "(h d) -> h d", h=self.num_C_head)
547
+
548
+ # SSM step
549
+ assert selective_state_update is not None
550
+ y = selective_state_update(ssm_state, x, dt, A, B, C, D, z=z, dt_bias=dt_bias, dt_softplus=True)
551
+ y = rearrange(y, "b h d -> b (h d)")
552
+ out = self.out_proj(y)
553
+
554
+ return out.unsqueeze(1), conv_state, ssm_state
555
+
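The fallback branch of `step` above (taken when `causal_conv1d_update` is unavailable) keeps a rolling window of the last `d_conv` inputs per channel and takes a dot product with the kernel taps. Below is a minimal single-channel sketch of that idea in plain Python, assuming a zero-initialised state. The helper names `conv_step` and `causal_conv_full` are illustrative and not part of this commit.

```python
def conv_step(state, weight, x_new, bias=0.0):
    """One decoding step of a causal conv for a single channel.

    state  : the last len(weight) inputs, oldest first (updated in place)
    weight : kernel taps, same length as state
    """
    # Shift the window left by one and write the new sample at the end,
    # mirroring torch.roll(..., shifts=-1) followed by state[..., -1] = x.
    state[:-1] = state[1:]
    state[-1] = x_new
    return sum(s * w for s, w in zip(state, weight)) + bias


def causal_conv_full(xs, weight, bias=0.0):
    """Reference: causal convolution over the whole sequence, zero left-padding."""
    k = len(weight)
    padded = [0.0] * (k - 1) + list(xs)
    return [
        sum(padded[t + j] * weight[j] for j in range(k)) + bias
        for t in range(len(xs))
    ]


xs = [1.0, 2.0, 3.0, 4.0, 5.0]          # hypothetical input sequence
weight = [0.5, -1.0, 2.0, 0.25]         # hypothetical kernel, d_conv = 4
state = [0.0] * len(weight)
stepwise = [conv_step(state, weight, x) for x in xs]
```

Feeding the sequence one token at a time through `conv_step` should give the same values as `causal_conv_full` on the whole sequence, which is why the decoding path only needs to cache `d_conv` samples per channel.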
556
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
557
+ device = self.out_proj.weight.device
558
+ conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
559
+ if self.repeat_kv_before_conv:
560
+ conv_state = torch.zeros(batch_size, self.d_inner, self.d_conv, device=device, dtype=conv_dtype)
561
+ else:
562
+ conv_state = torch.zeros(batch_size, self.d_xb, self.d_conv, device=device, dtype=conv_dtype)
563
+ ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
564
+ ssm_state = torch.zeros(
565
+ batch_size, self.num_C_head, self.d_inner // self.num_C_head, self.d_state, device=device, dtype=ssm_dtype
566
+ )
567
+ return conv_state, ssm_state
568
+
569
+ def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
570
+ """
571
+ conv_state: (batch, conv1d.weight.shape[0], d_conv)
572
+ ssm_state: (batch, num_C_head, d_inner // num_C_head, d_state)
573
+ """
574
+ assert self.layer_idx is not None
575
+
576
+ # Get states
577
+ ssm_states = inference_params.ssm_states[self.layer_idx]
578
+ conv_states = inference_params.conv_states[self.layer_idx]
579
+ if initialize_states:
580
+ ssm_states.zero_()
581
+ conv_states.zero_()
582
+ return ssm_states, conv_states
583
+
584
+
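The shapes allocated by `allocate_inference_cache` follow directly from the mixer hyperparameters. The sketch below only recomputes those shapes in plain Python. The function name `cache_shapes` and the numeric values in the example are assumptions for illustration, not values taken from this commit.

```python
def cache_shapes(batch_size, d_inner, d_xb, d_conv, d_state, num_c_head,
                 repeat_kv_before_conv):
    """Return (conv_state_shape, ssm_state_shape) as plain tuples."""
    # The conv cache holds d_conv samples per convolved channel. The channel
    # count depends on whether x is expanded to d_inner before the conv.
    conv_channels = d_inner if repeat_kv_before_conv else d_xb
    conv_state = (batch_size, conv_channels, d_conv)
    # The SSM cache holds one (head_dim, d_state) matrix per head.
    ssm_state = (batch_size, num_c_head, d_inner // num_c_head, d_state)
    return conv_state, ssm_state


# Hypothetical hyperparameters, chosen only to make the arithmetic concrete.
conv_shape, ssm_shape = cache_shapes(
    batch_size=2, d_inner=4096, d_xb=1024, d_conv=4, d_state=16,
    num_c_head=32, repeat_kv_before_conv=False,
)
```

Neither shape depends on the sequence length, so the per-layer cache for an `m2` block stays constant as decoding proceeds.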
585
+ class AprielSSMM2DecoderLayer(nn.Module):
586
+ _mixer_class = Mamba
587
+
588
+ def __init__(self, config: AprielHConfig, layer_idx: int, device=None, dtype=None, **kwargs):
589
+ super().__init__(**kwargs)
590
+ factory_kwargs = {"device": device, "dtype": dtype}
591
+ self.hidden_size = config.hidden_size
592
+
593
+ self.mixer = self._mixer_class(
594
+ d_model=config.hidden_size,
595
+ layer_idx=layer_idx,
596
+ **config.ssm_cfg,
597
+ **factory_kwargs,
598
+ )
599
+
600
+ self.mlp = MistralMLP(config)
601
+ self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
602
+ self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
603
+
604
+ def forward(
605
+ self, hidden_states: torch.Tensor, **kwargs
606
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
607
+
608
+ outputs = {}
609
+ residual = hidden_states
610
+
611
+ hidden_states = self.input_layernorm(hidden_states)
612
+
613
+ mixer_outputs = self.mixer(
614
+ hidden_states,
615
+ **kwargs,
616
+ )
617
+
618
+ hidden_states = mixer_outputs["hidden_states"].to(residual.dtype) + residual
619
+
620
+ # Fully Connected
621
+ residual = hidden_states
622
+ hidden_states = self.post_attention_layernorm(hidden_states)
623
+ hidden_states = self.mlp(hidden_states)
624
+ hidden_states = residual + hidden_states
625
+
626
+ outputs = (hidden_states,)
627
+
628
+ return outputs
629
+
630
+
631
+ class AprielHybridIdentity(nn.Module):
632
+ def __init__(self, config: AprielHConfig):
633
+ super().__init__()
634
+ self.config = config
635
+
636
+ def forward(self, hidden_states: torch.Tensor, **kwargs):
637
+ return (hidden_states,)
638
+
639
+
640
+ class AprielHModel(MistralModel):
641
+ """
642
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is one of [`AprielSSMM2DecoderLayer`, `MistralDecoderLayer`, `AprielHybridIdentity`], selected by `config.hybrid_block_layout`.
643
+ Args:
644
+ config: AprielHConfig
645
+ """
646
+
647
+ def __init__(self, config: AprielHConfig, **kwargs):
648
+ config_copy = copy.deepcopy(config)
649
+ config_copy.num_hidden_layers = 0
650
+ super().__init__(config_copy, **kwargs)
651
+ self.config = config
652
+ blocks = []
653
+ logger.info(f"Loading hybrid model with the following layout: {config.hybrid_block_layout}")
654
+ for layer_idx, type in enumerate(config.hybrid_block_layout):
655
+ if type == "m2":
656
+ blocks.append(AprielSSMM2DecoderLayer(config, layer_idx))
657
+ elif type == "t":
658
+ blocks.append(MistralDecoderLayer(config, layer_idx))
659
+ elif type == "i":
660
+ blocks.append(AprielHybridIdentity(config))
661
+ else:
662
+ raise ValueError(f"Invalid block type: {type}")
663
+ self.layers = nn.ModuleList(blocks)
664
+
665
+ # Initialize weights and apply final processing
666
+ self.post_init()
667
+
668
+ @can_return_tuple
669
+ def forward(
670
+ self,
671
+ input_ids: Optional[torch.LongTensor] = None,
672
+ attention_mask: Optional[torch.Tensor] = None,
673
+ position_ids: Optional[torch.LongTensor] = None,
674
+ past_key_values: Optional[Cache] = None,
675
+ inputs_embeds: Optional[torch.FloatTensor] = None,
676
+ use_cache: Optional[bool] = None,
677
+ output_attentions: Optional[bool] = None,
678
+ output_hidden_states: Optional[bool] = None,
679
+ cache_position: Optional[torch.LongTensor] = None,
680
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
681
+ ) -> BaseModelOutputWithPast:
682
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
683
+ if use_cache and past_key_values is None:
684
+ # for the case where prepare_inputs_for_generation is not called to create the cache (as in fast-llm test)
685
+ batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
686
+ past_key_values = HybridMambaAttentionDynamicCache(self.config, batch_size, self.dtype, device=self.device)
687
+ output = super().forward(
688
+ input_ids=input_ids,
689
+ attention_mask=attention_mask,
690
+ position_ids=position_ids,
691
+ past_key_values=past_key_values,
692
+ inputs_embeds=inputs_embeds,
693
+ use_cache=use_cache,
694
+ output_attentions=output_attentions,
695
+ output_hidden_states=output_hidden_states,
696
+ cache_position=cache_position,
697
+ **flash_attn_kwargs,
698
+ )
699
+ past_key_values: HybridMambaAttentionDynamicCache = output.past_key_values
700
+ if past_key_values and not past_key_values.has_previous_state:
701
+ past_key_values.has_previous_state = True
702
+ return output
703
+
704
+
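`AprielHModel.__init__` builds its layer stack by mapping each entry of `config.hybrid_block_layout` to a block class and rejecting unknown codes. A small stand-alone sketch of that dispatch is shown below. It maps the codes to class names as strings so that it runs without the model code; the helper `resolve_layout` and the example layout are hypothetical.

```python
# Block codes and the classes they select, as used in AprielHModel.__init__.
BLOCK_CLASS_NAMES = {
    "m2": "AprielSSMM2DecoderLayer",
    "t": "MistralDecoderLayer",
    "i": "AprielHybridIdentity",
}


def resolve_layout(layout):
    """Map each layout code to its block class name, failing on unknown codes."""
    resolved = []
    for layer_idx, block_type in enumerate(layout):
        if block_type not in BLOCK_CLASS_NAMES:
            raise ValueError(f"Invalid block type at layer {layer_idx}: {block_type}")
        resolved.append((layer_idx, BLOCK_CLASS_NAMES[block_type]))
    return resolved


# Hypothetical four-layer layout mixing attention, SSM, and identity blocks.
layers = resolve_layout(["t", "m2", "m2", "i"])
```

Using a name such as `block_type` for the loop variable, rather than `type`, also avoids shadowing the Python builtin.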
705
+ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
706
+
707
+
708
+ class AprielThinkerSSMHybridPreTrainedModel(PreTrainedModel):
709
+ config_class = AprielHConfig
710
+ base_model_prefix = "model"
711
+ _no_split_modules = ["MistralDecoderLayer", "AprielSSMDecoderLayer", "AprielSSMM2DecoderLayer"]
712
+ _skip_keys_device_placement = ["past_key_values"]
713
+ _supports_flash_attn_2 = True
714
+ _supports_sdpa = True
715
+ _supports_flex_attn = True
716
+ _supports_cache_class = True
717
+ _supports_quantized_cache = True
718
+ _supports_static_cache = True
719
+ _supports_attention_backend = True
720
+
721
+ def _init_weights(self, module):
722
+ std = self.config.initializer_range
723
+ if isinstance(module, nn.Linear):
724
+ module.weight.data.normal_(mean=0.0, std=std)
725
+ if module.bias is not None:
726
+ module.bias.data.zero_()
727
+ elif isinstance(module, nn.Embedding):
728
+ module.weight.data.normal_(mean=0.0, std=std)
729
+ if module.padding_idx is not None:
730
+ module.weight.data[module.padding_idx].zero_()
731
+ elif isinstance(module, MistralRMSNorm):
732
+ module.weight.data.fill_(1.0)
733
+
734
+
735
+ class AprielHForCausalLM(AprielThinkerSSMHybridPreTrainedModel, GenerationMixin):
736
+ _tied_weights_keys = ["lm_head.weight"]
737
+ _tp_plan = {"lm_head": "colwise_rep"}
738
+
739
+ def __init__(self, config: AprielHConfig, **kwargs):
740
+ super().__init__(config, **kwargs)
741
+ self.model = AprielHModel(config)
742
+ self.vocab_size = config.vocab_size
743
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
744
+
745
+ # Initialize weights and apply final processing
746
+ self.post_init()
747
+
748
+ def get_input_embeddings(self):
749
+ return self.model.embed_tokens
750
+
751
+ def set_input_embeddings(self, value):
752
+ self.model.embed_tokens = value
753
+
754
+ def get_output_embeddings(self):
755
+ return self.lm_head
756
+
757
+ def set_output_embeddings(self, new_embeddings):
758
+ self.lm_head = new_embeddings
759
+
760
+ def set_decoder(self, decoder):
761
+ self.model = decoder
762
+
763
+ def get_decoder(self):
764
+ return self.model
765
+
766
+ def prepare_inputs_for_generation(
767
+ self,
768
+ input_ids,
769
+ past_key_values=None,
770
+ attention_mask=None,
771
+ inputs_embeds=None,
772
+ output_router_logits=False,
773
+ cache_position=None,
774
+ position_ids=None,
775
+ use_cache=True,
776
+ **kwargs,
777
+ ):
778
+ # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache`
779
+
780
+ empty_past_kv = past_key_values is None or not isinstance(past_key_values, HybridMambaAttentionDynamicCache)
781
+
782
+ # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
783
+ # Exception 1: when passing inputs_embeds, input_ids may be missing entries
784
+ # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
785
+ # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
786
+ # (we can't check exception 3 while compiling)
787
+ if not empty_past_kv:
788
+ if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]: # Exception 1 # Exception 3
789
+ input_ids = input_ids[:, -cache_position.shape[0] :]
790
+ elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
791
+ input_ids = input_ids[:, cache_position]
792
+ else:
793
+ past_key_values = HybridMambaAttentionDynamicCache(
794
+ self.config, input_ids.shape[0], self.dtype, device=self.device
795
+ )
796
+
797
+ if attention_mask is not None and position_ids is None:
798
+ # create position_ids on the fly for batch generation
799
+ position_ids = attention_mask.long().cumsum(-1) - 1
800
+ position_ids.masked_fill_(attention_mask == 0, 1)
801
+ if not empty_past_kv:
802
+ position_ids = position_ids[:, -input_ids.shape[1] :]
803
+
804
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
805
+ if inputs_embeds is not None and empty_past_kv:
806
+ model_inputs = {"inputs_embeds": inputs_embeds}
807
+ else:
808
+ model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
809
+
810
+ model_inputs.update(
811
+ {
812
+ "position_ids": position_ids,
813
+ "past_key_values": past_key_values,
814
+ "use_cache": use_cache,
815
+ "attention_mask": attention_mask,
816
+ "output_router_logits": output_router_logits,
817
+ "cache_position": cache_position,
818
+ }
819
+ )
820
+ return model_inputs
821
+
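`prepare_inputs_for_generation` derives `position_ids` from the attention mask with a cumulative sum, then overwrites padded positions with a dummy value of 1. The plain-Python sketch below reproduces that arithmetic on nested lists. The helper name `position_ids_from_mask` and the example mask are illustrative only.

```python
from itertools import accumulate


def position_ids_from_mask(attention_mask):
    """Per row: cumulative sum of the mask minus 1, with padded slots set to 1."""
    position_ids = []
    for row in attention_mask:
        running = list(accumulate(row))
        # Real tokens get 0, 1, 2, ... and padding gets the placeholder value 1,
        # matching cumsum(-1) - 1 followed by masked_fill_(mask == 0, 1).
        position_ids.append([c - 1 if m else 1 for c, m in zip(running, row)])
    return position_ids


# Hypothetical batch: first row is left-padded by two tokens.
mask = [[0, 0, 1, 1, 1],
        [1, 1, 1, 1, 1]]
ids = position_ids_from_mask(mask)
```

With left padding, the first real token of every row receives position 0 regardless of how much padding precedes it.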
822
+ def forward(
823
+ self,
824
+ input_ids: Optional[torch.LongTensor] = None,
825
+ attention_mask: Optional[torch.Tensor] = None,
826
+ position_ids: Optional[torch.LongTensor] = None,
827
+ past_key_values: Optional[Cache] = None,
828
+ inputs_embeds: Optional[torch.FloatTensor] = None,
829
+ labels: Optional[torch.LongTensor] = None,
830
+ use_cache: Optional[bool] = None,
831
+ output_attentions: Optional[bool] = None,
832
+ output_hidden_states: Optional[bool] = None,
833
+ cache_position: Optional[torch.LongTensor] = None,
834
+ logits_to_keep: Union[int, torch.Tensor] = 0,
835
+ **kwargs: Unpack[KwargsForCausalLM],
836
+ ) -> Union[tuple, CausalLMOutputWithPast]:
837
+ r"""
838
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
839
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
840
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
841
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
842
+
843
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
844
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
845
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
846
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
847
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
848
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
849
+
850
+ Returns:
851
+
852
+ Example:
853
+
854
+ ```python
855
+ >>> from transformers import AutoTokenizer, MistralForCausalLM
856
+
857
+ >>> model = MistralForCausalLM.from_pretrained("meta-mistral/Mistral-2-7b-hf")
858
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-mistral/Mistral-2-7b-hf")
859
+
860
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
861
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
862
+
863
+ >>> # Generate
864
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
865
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
866
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
867
+ ```"""
868
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
869
+ output_hidden_states = (
870
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
871
+ )
872
+
873
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
874
+ outputs: BaseModelOutputWithPast = self.model(
875
+ input_ids=input_ids,
876
+ attention_mask=attention_mask,
877
+ position_ids=position_ids,
878
+ past_key_values=past_key_values,
879
+ inputs_embeds=inputs_embeds,
880
+ use_cache=use_cache,
881
+ output_attentions=output_attentions,
882
+ output_hidden_states=output_hidden_states,
883
+ cache_position=cache_position,
884
+ mamba_mask=attention_mask, # non-expanded mask
885
+ **kwargs,
886
+ )
887
+
888
+ hidden_states = outputs.last_hidden_state
889
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
890
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
891
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
892
+
893
+ loss = None
894
+ if labels is not None:
895
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
896
+
897
+ return AprielHybridCausalOutput(
898
+ loss=loss,
899
+ logits=logits,
900
+ all_hidden_states=outputs.hidden_states,
901
+ past_key_values=outputs.past_key_values,
902
+ )
903
+
904
+
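The `logits_to_keep` handling in `forward` relies on `slice(-0, None)` being the same as `slice(0, None)`, which is why the value `0` keeps every position. A short sketch of the index selection on a plain list is given below. The helper `kept_positions` is hypothetical and stands in for indexing `hidden_states` along the sequence dimension.

```python
def kept_positions(seq_len, logits_to_keep):
    """Return the sequence positions whose logits would be computed."""
    positions = list(range(seq_len))
    if isinstance(logits_to_keep, int):
        # -0 == 0, so logits_to_keep=0 selects the full sequence.
        return positions[slice(-logits_to_keep, None)]
    # Otherwise treat the argument as explicit 1D indices to keep.
    return [positions[i] for i in logits_to_keep]


all_positions = kept_positions(5, 0)
last_only = kept_positions(5, 1)
explicit = kept_positions(5, [0, 2])
```

During generation only the last position is needed, so passing `1` avoids projecting the whole sequence through `lm_head`.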
905
+ __all__ = [
906
+ "AprielHForCausalLM",
907
+ "AprielHModel",
908
+ ]
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff