Commit · dac8a18
Parent(s): initial commit
- .gitattributes +36 -0
- README.md +279 -0
- config.json +104 -0
- configuration_apriel_h.py +37 -0
- images/apriel_h.png +3 -0
- images/apriel_h_vs_apriel_15b_eval_thrput_comparison.png +3 -0
- images/throughput_eval_score_vs_throughput_1-16k_annotated.png +3 -0
- model.safetensors.index.json +915 -0
- model_0.safetensors +3 -0
- model_1.safetensors +3 -0
- model_2.safetensors +3 -0
- model_3.safetensors +3 -0
- modeling_apriel_h.py +908 -0
- tokenizer.json +0 -0
- tokenizer_config.json +0 -0
.gitattributes
ADDED
@@ -0,0 +1,36 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
*.png filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,279 @@
---
license: mit
pipeline_tag: text-generation
library_name: transformers
track_downloads: true
---

# Apriel-H1-15b-Thinker

[](https://huggingface.co/docs/transformers/index)

<img src="images/apriel_h.png" width="120" alt="thumbnail"/> `/ˈɑː.pri.əl/`

A 15B-parameter **hybrid reasoning model** combining Transformer attention and Mamba State Space layers for high efficiency and scalability. Derived from *Apriel-Nemotron-15B-Thinker* through progressive distillation, **Apriel-H1** replaces less critical attention layers with linear Mamba blocks—achieving **over 2× higher inference throughput** in vLLM with **minimal loss in reasoning, math, and coding performance**.

- **Model Size:** 15B parameters
- **Context Length:** 65K (target; runtime dependent)
- **Languages:** English (best)

## Highlights

- Hybrid Transformer–SSM architecture
- ~2× throughput improvement over the base Thinker model
- Retains strong reasoning, math, and coding capabilities
- Built via efficient distillation—no training from scratch required

## Model Overview
Apriel-H1-15b-Thinker is designed for agentic tasks, code assistance, and multi-step reasoning. It follows Apriel’s “think then answer” style: the model first produces a hidden chain-of-thought and then a concise final response. Where reasoning traces are undesired, configure prompts to favor concise outputs.

**Technical report**: <a href="https://github.com/ServiceNow/apriel/blob/main/assets/Apriel-H1.pdf" target="_blank" rel="noopener noreferrer">Apriel-H1 Report</a>

### Efficient and strong among hybrids


All models were evaluated against vLLM server endpoints using FlashInfer (except AI21-Jamba-Reasoning-3B, which used FlashAttention2). For NVIDIA-Nemotron-Nano-9B-v2 and AI21-Jamba-Reasoning-3B, `mamba_cache` was set to fp32.

#### Comparison with Thinker: ~2× speedup
<img src="images/apriel_h_vs_apriel_15b_eval_thrput_comparison.png" width='auto'>

## How to Use

Install dependencies:

```bash
pip install transformers==4.53.2
```

Basic usage with Transformers generate:

```python
import re
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "ServiceNow-AI/Apriel-H1-15b-Thinker-SFT"

# load the tokenizer and the model
# (trust_remote_code is required because the hybrid Apriel-H architecture ships as custom code in the repo)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True
)

prompt = "Positive real numbers $x$ and $y$ satisfy $y^3=x^2$ and $(y-x)^2=4y^2$. What is $x+y$?\nMark your solution with \\boxed"
messages = [
    {"role": "user", "content": prompt}
]

text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    tools=[]
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

generated_ids = model.generate(**model_inputs, max_new_tokens=1024)
output = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

# the final answer is enclosed in [BEGIN FINAL RESPONSE] ... [END FINAL RESPONSE]
response = re.findall(r"\[BEGIN FINAL RESPONSE\](.*?)\[END FINAL RESPONSE\]", output, re.DOTALL)[0].strip()
print("response:", response)
```

Recommended settings: temperature 0.6; increase `max_new_tokens` for complex reasoning.
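
To apply these settings, you can swap the `generate` call in the example above for a sampled one (standard Transformers `generate` arguments; the exact token budget shown here is only an illustration):

```python
generated_ids = model.generate(
    **model_inputs,
    do_sample=True,
    temperature=0.6,        # recommended sampling temperature
    max_new_tokens=4096,    # raise for long reasoning traces
)
```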

## Use it with vLLM

### 💻 Local Installation

#### 1. Create and activate a Python environment
You can use any environment manager. The example below uses [`uv`](https://github.com/astral-sh/uv):

```bash
uv venv --python 3.12 --seed
source .venv/bin/activate
```

#### 2. Install vLLM and the Apriel plugin

Find our plugin at [https://github.com/ServiceNow/apriel](https://github.com/ServiceNow/apriel).
You may need to install a version of **vLLM** compatible with your **CUDA** version.

In this example, we use the default CUDA version and let vLLM automatically select the correct backend.

```bash
git clone git@github.com:ServiceNow/apriel.git
cd apriel
uv pip install vllm==0.10.2 --torch-backend=auto
pip install .
```

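As a quick sanity check after the install (purely illustrative), you can confirm which vLLM build ended up in the environment:

```python
import vllm

print(vllm.__version__)  # expected: 0.10.2
```
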
<!--
### 🐳 Use Prebuilt Docker Image

For convenience, prebuilt Docker images are available from **GitHub Container Registry (GHCR):**

**Repository:**
[https://github.com/ServiceNow/apriel](https://github.com/ServiceNow/apriel)

**Container packages:**
[https://github.com/ServiceNow/apriel/pkgs/container/apriel](https://github.com/ServiceNow/apriel/pkgs/container/apriel)

Pull the latest version or a specific tagged build (e.g., commit SHA):

```bash
docker pull ghcr.io/servicenow/apriel:latest
# or a specific digest
docker pull ghcr.io/servicenow/apriel:sha-e41528d
```
-->
### 🧠 Running a vLLM Server

#### Option 1: Run locally (from source install)

Once installed, you can launch a vLLM OpenAI-compatible API server with your Apriel model:

```bash
vllm serve \
    --model ServiceNow-AI/Apriel-H1-15b-Thinker-SFT \
    --port 8000
```

#### Option 2: Run via Docker

You can run the server directly using the prebuilt container:

```bash
docker run --runtime nvidia --gpus all \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
    -p 8000:8000 \
    --ipc=host \
    ghcr.io/servicenow/apriel:latest \
    --model ServiceNow-AI/Apriel-H1-15b-Thinker-SFT
```

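With either option, the server exposes an OpenAI-compatible API. A minimal sketch of querying it with the `openai` Python client (the `api_key` value is a placeholder; a local vLLM server does not check it unless you configure one, and the host/port match the commands above):

```python
from openai import OpenAI

# point the client at the local vLLM server started above
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="ServiceNow-AI/Apriel-H1-15b-Thinker-SFT",
    messages=[{"role": "user", "content": "What is 15% of 80? Show your reasoning."}],
    temperature=0.6,
    max_tokens=1024,
)
print(response.choices[0].message.content)
```
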
## Chat Template

```
<|system|>
You are a thoughtful and systematic AI assistant built by ServiceNow Language Models (SLAM) lab. Before providing an answer, analyze the problem carefully and present your reasoning step by step. After explaining your thought process, provide the final solution in the following format: [BEGIN FINAL RESPONSE] ... [END FINAL RESPONSE].
<|end|>
<|user|>
# user message here
<|end|>
<|assistant|>
Here are my reasoning steps:
# thoughts here
[BEGIN FINAL RESPONSE]
# assistant response here
[END FINAL RESPONSE]
<|end|>
```
The model will first generate its thinking process and then generate its final response between `[BEGIN FINAL RESPONSE]` and `[END FINAL RESPONSE]`. Here is a code snippet demonstrating the application of the chat template:

```python
from transformers import AutoTokenizer
model_name = "ServiceNow-AI/Apriel-H1-15b-Thinker-SFT"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# prepare the model input
custom_system_prompt = "Answer like a pirate."
prompt = "You are an expert assistant in the implementation of customer experience management aspect of retail applications \n \nYou will be using Python as the programming language. \n \nYou will utilize a factory design pattern for the implementation and following the dependency inversion principle \n \nYou will modify the implementation based on user requirements. \n \nUpon user request, you will add, update, and remove the features & enhancements in the implementation provided by you. \n \nYou will ask whether the user wants to refactor the provided code or needs a sample implementation for reference. Upon user confirmation, I will proceed accordingly. \n \n**Guidelines:** \n 1. **User Requirements:** \n - You have to ask users about their requirements, clarify the user expectations, and suggest the best possible solution by providing examples of Python code snippets. \n - Ask users about which type of reports they need to assess the AI model's performance, accuracy, and reliability. \n - After providing the solution, you have to ask the user about the trial of the solution and modify the solution based on the user feedback. \n \n 2. **Libraries/Frameworks:** \n - You will be utilizing Python as a programming language. \n - You will be using Flask framework for REST APIS implementation \n \n 3. **Communication Gesture:** \n - Your conversation with the user should be interactive, supportive, courageous, and professional. \n - You have to break down the complex concepts into sub-concepts and try to explain them to the user. \n - You have to ask the user for the required parameters. If the user refuses to provide in 2 attempts, politely exit the conversation. \n - You have to provide your supported parameters to the user, if the user refuses to accept them then you have to put an apology note and exit the conversation. \n - You have to track the conversation about unasked questions by the user. If some/one of the questions remain then you have to remind the user about these questions and proceed to answer them based on the user's confirmation \n \n 4. **Implementation:** \n - Your code/implementations should be reliable, scalable, modular, and reusable. \n - You will be providing unit tests for the implementation upon user request. \n - You will be following MVC architecture for the applications \n - Your implementations must be well-commented and readable \n \n \n- Today's date is 23rd August 2024. \n- The default sender email is [email protected].\nHi, I am conducting research on retail customer feedback systems and I need assistance with designing and implementing them. Could you kindly provide me with a list of general customer feedback system modules?"
messages = [
    {"role": "user", "content": custom_system_prompt + "\n\n" + prompt}
]
# example tools
tools = [{"type": "function", "function": {"name": "getRetailFeedbackModules", "description": "Returns the list of modules usually present in the retail industry", "parameters": {"type": "object", "properties": {"page": {"type": "integer", "description": "The current page number.", "default": 1}, "page_size": {"type": "integer", "description": "The number of items per page.", "default": 3}}}}}, {"type": "function", "function": {"name": "verifyImplementation", "description": "Returns the list of modules usually present in the retail industry", "parameters": {"type": "object", "properties": {"coding_language": {"type": "string", "description": "The supported languages for verification of implementation.", "default": "python", "enum": ["python", "java", "php"]}, "code": {"type": "string", "description": "The code which needs verification"}, "design_pattern": {"type": "string", "description": "The design pattern to verify in the implementation", "enum": ["factory", "strategy", "singleton"]}, "verify_best_practices": {"type": "boolean", "description": "The verification of the coding style based on the language selected", "default": true}}}}}]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    tools=tools
)
model_inputs = tokenizer([text], return_tensors="pt")
```

### Usage Guidelines
1. Use the model’s default chat template, which already includes a system prompt. We recommend adding all other instructions within the user message.
2. We recommend setting temperature to `0.6`.
3. We ensure the model starts with `Here are my reasoning steps:\n` during all our evaluations. This is implemented in the default chat template.

## Intended Use

The Apriel family of models is designed for a variety of general-purpose instruction tasks, including:

- Code assistance and generation
- Logical reasoning and multi-step tasks
- Question answering and information retrieval
- Function calling, complex instruction following and agent use cases

They are **not intended** for use in safety-critical applications without human oversight or in scenarios requiring guaranteed factual accuracy.

---

## Limitations

- **Factual accuracy:** May produce incorrect, misleading, or outdated content. Outputs should be verified before use in critical contexts.
- **Bias:** May reflect societal, cultural, or systemic biases present in training data.
- **Ethics:** Do not use the model to produce harmful, unlawful, or unethical content.
- **Language:** Strongest performance is in English. Output quality may degrade in underrepresented languages.
- **Critical use:** Not suitable for medical, legal, financial, or other high-risk applications without safeguards.

---

## Security and Responsible Use

**Security Responsibilities:**
Deployers and users are strongly encouraged to align their security practices with established frameworks and regulatory guidelines such as the EU AI Act and the NIST AI Risk Management Framework (RMF).

**Guidelines for Deployers:**

- Regularly conduct robustness assessments to identify and mitigate adversarial inputs.
- Implement validation and filtering processes to prevent harmful or biased outputs.
- Continuously perform data privacy checks to guard against unintended data leaks.
- Document and communicate the model's limitations, intended usage, and known security risks to all end-users.
- Schedule periodic security reviews and updates to address emerging threats and vulnerabilities.

**Guidelines for Users:**

- Follow established security policies and usage guidelines provided by deployers.
- Protect and manage sensitive information when interacting with the model.
- Report anomalies, suspicious behavior, or unsafe outputs to deployers or developers.
- Maintain human oversight and apply judgment to mitigate potential security or ethical risks during interactions.

**Disclaimer:**
Users accept responsibility for securely deploying, managing, and using this open-source LLM. The model is provided "as-is," without explicit or implied warranty regarding security or fitness for any specific application or environment.

---

## Software

- **Training stack:** [Fast-LLM](https://github.com/ServiceNow/Fast-LLM)

---

## License

MIT

---

## Citation

```bibtex
@misc{apriel_h1_2025,
  title = {Apriel-H1: Towards Efficient Enterprise Reasoning Models},
  author = {ServiceNow Language Models Lab},
  howpublished = {https://huggingface.co/ServiceNow-AI/Apriel-H1-15b-Thinker-SFT},
  year = {2025}
}
```
config.json
ADDED
@@ -0,0 +1,104 @@
{
  "architectures": [
    "AprielHForCausalLM"
  ],
  "attention_dropout": 0.0,
  "auto_map": {
    "AutoConfig": "configuration_apriel_h.AprielHConfig",
    "AutoModel": "modeling_apriel_h.AprielHModel",
    "AutoModelForCausalLM": "modeling_apriel_h.AprielHForCausalLM"
  },
  "bos_token_id": 1,
  "eos_token_id": 2,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 5120,
  "hybrid_block_layout": [
    "t", "t", "t", "m2", "t", "m2", "m2", "m2", "m2", "t",
    "t", "t", "t", "t", "t", "t", "m2", "t", "m2", "m2",
    "m2", "m2", "m2", "m2", "m2", "m2", "m2", "m2", "t", "t",
    "m2", "m2", "m2", "m2", "t", "m2", "m2", "m2", "m2", "m2",
    "m2", "m2", "m2", "m2", "m2", "m2", "m2", "m2", "t", "m2"
  ],
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 65536,
  "model_type": "apriel_h",
  "num_attention_heads": 32,
  "num_hidden_layers": 50,
  "num_key_value_heads": 8,
  "rms_norm_eps": 1e-05,
  "rope_scaling": {
    "rope_type": "default"
  },
  "rope_theta": 1000000.0,
  "sliding_window": null,
  "ssm_cfg": {
    "activation": "silu",
    "bias": false,
    "chunk_size": 128,
    "conv_bias": true,
    "d_conv": 4,
    "d_inner": 4096,
    "d_state": 16,
    "d_xb": 1024,
    "dt_init": "random",
    "dt_init_floor": 0.0001,
    "dt_max": 0.1,
    "dt_min": 0.001,
    "dt_rank": 320,
    "dt_scale": 1.0,
    "expand": 1,
    "n_qk_heads": 32,
    "n_v_heads": 32
  },
  "tie_word_embeddings": false,
  "transformers_version": "4.53.2",
  "use_cache": true,
  "vocab_size": 131072
}
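In `hybrid_block_layout` above, each of the 50 entries selects the mixer for one decoder layer: `"t"` keeps a Transformer attention block, while `"m2"` substitutes a Mamba block (the layers whose weights appear as `mixer.*` in the shard index further below). A small sketch, assuming `config.json` has been downloaded locally, that tallies the split:

```python
import json

with open("config.json") as f:
    config = json.load(f)

layout = config["hybrid_block_layout"]
# count how many layers keep attention ("t") vs. use a Mamba mixer ("m2")
print(f"{len(layout)} layers: {layout.count('t')} attention, {layout.count('m2')} Mamba")
```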
configuration_apriel_h.py
ADDED
@@ -0,0 +1,37 @@
from transformers import MistralConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)

ssm_config_default = {
    "d_state": 64,
    "n_qk_heads": 32,
    "expand": 1,
    "chunk_size": 128,
    "activation": "identity",
    "bias": False,
    "d_conv": 4,
    "d_inner": 32 * 128,
    "d_xb": None,  # will be set to model dim
    "dt_rank": "auto",
    "dt_min": 0.001,
    "dt_max": 0.1,
    "dt_init": "random",
    "dt_scale": 1.0,
    "dt_init_floor": 1e-4,
    "conv_bias": True,
}


class AprielHConfig(MistralConfig):
    model_type = "apriel_h"

    def __init__(self, hybrid_block_layout=["m2"], ssm_cfg=None, **kwargs):
        super().__init__(**kwargs)
        self.hybrid_block_layout = hybrid_block_layout
        self.head_dim = self.head_dim or self.hidden_size // self.num_attention_heads  # as in transformers 4.51.3
        self.ssm_cfg = ssm_cfg or ssm_config_default

        for k, v in ssm_config_default.items():
            if k not in self.ssm_cfg:
                self.ssm_cfg[k] = v  # to make sure all elements are present in the config
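As a small illustration of the class above (not part of the repository): `AprielHConfig` keeps any `ssm_cfg` keys you pass and back-fills the rest from `ssm_config_default`. This sketch assumes `configuration_apriel_h.py` is importable from the working directory:

```python
from configuration_apriel_h import AprielHConfig, ssm_config_default

# two-layer toy layout; only d_state is overridden in ssm_cfg
cfg = AprielHConfig(hybrid_block_layout=["t", "m2"], ssm_cfg={"d_state": 16})

assert cfg.ssm_cfg["d_state"] == 16                            # user-supplied value is kept
assert cfg.ssm_cfg["d_conv"] == ssm_config_default["d_conv"]   # missing keys filled from the defaults
print(cfg.hybrid_block_layout)
```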
images/apriel_h.png
ADDED
Git LFS Details

images/apriel_h_vs_apriel_15b_eval_thrput_comparison.png
ADDED
Git LFS Details

images/throughput_eval_score_vs_throughput_1-16k_annotated.png
ADDED
Git LFS Details
model.safetensors.index.json
ADDED
@@ -0,0 +1,915 @@
{
  "metadata": {
    "fast_llm_metadata": {
      "fast_llm_version": "0.2.0",
      "model": "hybrid_ssm",
      "format": "apriel_ssm_thinker_hybrid",
      "config": {
        "type": "hybrid_ssm",
        "base_model": {
          "transformer": {
            "type": "lm_decoder",
            "normalization": {"type": "rms_norm", "epsilon": 1e-05},
            "rotary": {"type": "default", "theta": 1000000.0},
            "peft": {"type": "none"},
            "num_layers": 50,
            "hidden_size": 5120,
            "num_attention_heads": 32,
            "head_groups": 8,
            "add_linear_biases": false,
            "ffn_hidden_size": 14336,
            "kv_channels": 128,
            "gated": true,
            "activation_type": "silu",
            "mlp_lr_scale": 1.0,
            "attention_lr_scale": 1.0
          },
          "vision_encoder": {
            "transformer": {
              "normalization": {"type": "layer_norm"},
              "rotary": {"type": "none"},
              "peft": {"type": "none"}
            },
            "patch_norm": {"type": "layer_norm"}
          },
          "vocab_size": 131072,
          "use_position_embeddings": false,
          "tie_word_embeddings": false,
          "cross_entropy_impl": "fused",
          "distillation_loss_implementation": "reverse_kl",
          "distillation_model": "teacher",
          "parallel_embeddings": false,
          "embeddings_lr_scale": 1.0,
          "output_lr_scale": 1.0,
          "ssm": {
            "normalization": {"type": "layer_norm"},
            "expansion_factor": 1,
            "state_size": 16,
            "conv_kernel_dimension": 4,
            "dt_rank": 320,
            "n_qk_heads": 32,
            "n_v_heads": 32,
            "d_inner": 4096,
            "d_xb": 1024,
            "add_bias_linear": false,
            "activation_type": "silu",
            "chunk_size": 128,
            "dt_init": "random",
            "dt_scale": 1.0,
            "dt_min": 0.001,
            "dt_max": 0.1,
            "dt_init_floor": 0.0001
          },
          "hybrid_block_layout": [
            "t", "t", "t", "m2", "t", "m2", "m2", "m2", "m2", "t",
            "t", "t", "t", "t", "t", "t", "m2", "t", "m2", "m2",
            "m2", "m2", "m2", "m2", "m2", "m2", "m2", "m2", "t", "t",
            "m2", "m2", "m2", "m2", "t", "m2", "m2", "m2", "m2", "m2",
            "m2", "m2", "m2", "m2", "m2", "m2", "m2", "m2", "t", "m2"
          ]
        },
        "multi_stage": {"zero_stage": 3},
        "distributed": {
          "tensor_parallel": 8,
          "sequence_tensor_parallel": true,
          "world_size": 64,
          "rank": 0,
          "local_world_size": 8,
          "timeout": 36000.0,
          "seed": 984060,
          "training_dtype": "bfloat16"
        }
      },
      "shards": ["weights"],
      "metadata": {
        "optimizer": {
          "current_step": 25000,
          "grad_scaler": {"type": "NoopGradScaler"}
        },
        "completed_steps": 25000,
        "metrics": {
          "Training": {
            "train_iters": 60000,
            "batch_size": 64,
            "iteration": 25000,
            "distillation_loss": 0.034176599606871604,
            "language_model_loss": 0.034176599606871604,
            "consumed_samples": 1600000,
            "consumed_tokens": 26214400000,
            "step_time_ms": 13669.109502492938,
            "step_time_average_ms": 14812.520303467,
            "remaining_time": 518438.21062134503,
            "completion_time": 1757381892.6039205,
            "percent_done": 41.666666666666664,
            "skipped_iters": 0,
            "nan_iters": 0,
            "model_tflops": 126.99100713435496,
            "hardware_tflops": 131.01289176252138,
            "tokens_per_sec_per_gpu": 1198.6150229473199,
            "run": 0,
            "grad_norm": 0.24984633922576904,
            "learning_rate": 2.968135593220339e-06,
            "loss_scale": 1.0,
            "reserved": 55520.0,
            "allocated": 17871.4111328125,
            "max_allocated": 48712.857421875,
            "max_reserved": 55520.0,
            "global_max_reserved": 55520.0
          }
        }
      }
    },
    "model_config": {
      "model_type": "apriel_ssm_thinker_hybrid",
      "architectures": ["AprielThinkerSSMHybridForCausalLM"],
      "rope_theta": 1000000.0,
      "hidden_act": "silu",
      "num_hidden_layers": 50,
      "hidden_size": 5120,
      "num_attention_heads": 32,
      "num_key_value_heads": 8,
      "intermediate_size": 14336,
      "vocab_size": 131072,
      "tie_word_embeddings": false,
      "rms_norm_eps": 1e-05,
      "head_dim": 128,
      "rope_scaling": {"rope_type": "default"},
      "ssm_cfg": {
        "d_state": 16,
        "n_v_heads": 32,
        "n_qk_heads": 32,
        "expand": 1,
        "chunk_size": 128,
        "bias": false,
        "activation": "silu",
        "dt_rank": 320,
        "dt_min": 0.001,
        "dt_max": 0.1,
        "dt_init_floor": 0.0001,
        "dt_scale": 1.0,
        "d_xb": 1024,
        "d_conv": 4,
        "dt_init": "random",
        "d_inner": 4096,
        "conv_bias": true
      },
      "hybrid_block_layout": [
        "t", "t", "t", "m2", "t", "m2", "m2", "m2", "m2", "t",
        "t", "t", "t", "t", "t", "t", "m2", "t", "m2", "m2",
        "m2", "m2", "m2", "m2", "m2", "m2", "m2", "m2", "t", "t",
        "m2", "m2", "m2", "m2", "t", "m2", "m2", "m2", "m2", "m2",
        "m2", "m2", "m2", "m2", "m2", "m2", "m2", "m2", "t", "m2"
      ],
      "auto_map": {
        "AutoConfig": "configuration_ssm_hybrid_apriel15b.AprielSSMHybridConfig",
        "AutoModel": "modeling_ssm_hybrid_apriel15b.AprielThinkerSSMHybridModel",
        "AutoModelForCausalLM": "modeling_ssm_hybrid_apriel15b.AprielThinkerSSMHybridForCausalLM"
      },
      "attn_implementation": null
    },
    "format": "pt"
  },
  "weight_map": {
    "model.embed_tokens.weight": "model_0.safetensors",
    "model.layers.0.input_layernorm.weight": "model_0.safetensors",
    "model.layers.0.post_attention_layernorm.weight": "model_0.safetensors",
    "model.layers.0.self_attn.q_proj.weight": "model_0.safetensors",
    "model.layers.0.self_attn.k_proj.weight": "model_0.safetensors",
    "model.layers.0.self_attn.v_proj.weight": "model_0.safetensors",
    "model.layers.0.self_attn.o_proj.weight": "model_0.safetensors",
    "model.layers.0.mlp.gate_proj.weight": "model_0.safetensors",
    "model.layers.0.mlp.up_proj.weight": "model_0.safetensors",
    "model.layers.0.mlp.down_proj.weight": "model_0.safetensors",
    "model.layers.1.input_layernorm.weight": "model_0.safetensors",
    "model.layers.1.post_attention_layernorm.weight": "model_0.safetensors",
    "model.layers.1.self_attn.q_proj.weight": "model_0.safetensors",
    "model.layers.1.self_attn.k_proj.weight": "model_0.safetensors",
    "model.layers.1.self_attn.v_proj.weight": "model_0.safetensors",
    "model.layers.1.self_attn.o_proj.weight": "model_0.safetensors",
    "model.layers.1.mlp.gate_proj.weight": "model_0.safetensors",
    "model.layers.1.mlp.up_proj.weight": "model_0.safetensors",
    "model.layers.1.mlp.down_proj.weight": "model_0.safetensors",
    "model.layers.2.input_layernorm.weight": "model_0.safetensors",
    "model.layers.2.post_attention_layernorm.weight": "model_0.safetensors",
    "model.layers.2.self_attn.q_proj.weight": "model_0.safetensors",
    "model.layers.2.self_attn.k_proj.weight": "model_0.safetensors",
    "model.layers.2.self_attn.v_proj.weight": "model_0.safetensors",
    "model.layers.2.self_attn.o_proj.weight": "model_0.safetensors",
    "model.layers.2.mlp.gate_proj.weight": "model_0.safetensors",
    "model.layers.2.mlp.up_proj.weight": "model_0.safetensors",
    "model.layers.2.mlp.down_proj.weight": "model_0.safetensors",
    "model.layers.3.mixer.A_log": "model_0.safetensors",
    "model.layers.3.mixer.D": "model_0.safetensors",
    "model.layers.3.input_layernorm.weight": "model_0.safetensors",
    "model.layers.3.post_attention_layernorm.weight": "model_0.safetensors",
    "model.layers.3.mixer.dt_in_proj.weight": "model_0.safetensors",
    "model.layers.3.mixer.conv1d.weight": "model_0.safetensors",
    "model.layers.3.mixer.conv1d.bias": "model_0.safetensors",
    "model.layers.3.mixer.dt_proj.bias": "model_0.safetensors",
    "model.layers.3.mixer.in_proj.weight": "model_0.safetensors",
    "model.layers.3.mixer.dt_proj.weight": "model_0.safetensors",
    "model.layers.3.mixer.out_proj.weight": "model_0.safetensors",
    "model.layers.3.mlp.gate_proj.weight": "model_0.safetensors",
    "model.layers.3.mlp.up_proj.weight": "model_0.safetensors",
    "model.layers.3.mlp.down_proj.weight": "model_0.safetensors",
    "model.layers.4.input_layernorm.weight": "model_0.safetensors",
    "model.layers.4.post_attention_layernorm.weight": "model_0.safetensors",
    "model.layers.4.self_attn.q_proj.weight": "model_0.safetensors",
    "model.layers.4.self_attn.k_proj.weight": "model_0.safetensors",
    "model.layers.4.self_attn.v_proj.weight": "model_0.safetensors",
    "model.layers.4.self_attn.o_proj.weight": "model_0.safetensors",
    "model.layers.4.mlp.gate_proj.weight": "model_0.safetensors",
    "model.layers.4.mlp.up_proj.weight": "model_0.safetensors",
    "model.layers.4.mlp.down_proj.weight": "model_0.safetensors",
    "model.layers.5.mixer.A_log": "model_0.safetensors",
    "model.layers.5.mixer.D": "model_0.safetensors",
    "model.layers.5.input_layernorm.weight": "model_0.safetensors",
    "model.layers.5.post_attention_layernorm.weight": "model_0.safetensors",
    "model.layers.5.mixer.dt_in_proj.weight": "model_0.safetensors",
    "model.layers.5.mixer.conv1d.weight": "model_0.safetensors",
    "model.layers.5.mixer.conv1d.bias": "model_0.safetensors",
    "model.layers.5.mixer.dt_proj.bias": "model_0.safetensors",
    "model.layers.5.mixer.in_proj.weight": "model_0.safetensors",
    "model.layers.5.mixer.dt_proj.weight": "model_0.safetensors",
    "model.layers.5.mixer.out_proj.weight": "model_0.safetensors",
    "model.layers.5.mlp.gate_proj.weight": "model_0.safetensors",
    "model.layers.5.mlp.up_proj.weight": "model_0.safetensors",
    "model.layers.5.mlp.down_proj.weight": "model_0.safetensors",
    "model.layers.6.mixer.A_log": "model_0.safetensors",
    "model.layers.6.mixer.D": "model_0.safetensors",
    "model.layers.6.input_layernorm.weight": "model_0.safetensors",
    "model.layers.6.post_attention_layernorm.weight": "model_0.safetensors",
    "model.layers.6.mixer.dt_in_proj.weight": "model_0.safetensors",
    "model.layers.6.mixer.conv1d.weight": "model_0.safetensors",
    "model.layers.6.mixer.conv1d.bias": "model_0.safetensors",
    "model.layers.6.mixer.dt_proj.bias": "model_0.safetensors",
    "model.layers.6.mixer.in_proj.weight": "model_0.safetensors",
    "model.layers.6.mixer.dt_proj.weight": "model_0.safetensors",
    "model.layers.6.mixer.out_proj.weight": "model_0.safetensors",
    "model.layers.6.mlp.gate_proj.weight": "model_0.safetensors",
    "model.layers.6.mlp.up_proj.weight": "model_0.safetensors",
    "model.layers.6.mlp.down_proj.weight": "model_0.safetensors",
    "model.layers.7.mixer.A_log": "model_0.safetensors",
    "model.layers.7.mixer.D": "model_0.safetensors",
    "model.layers.7.input_layernorm.weight": "model_0.safetensors",
    "model.layers.7.post_attention_layernorm.weight": "model_0.safetensors",
    "model.layers.7.mixer.dt_in_proj.weight": "model_0.safetensors",
    "model.layers.7.mixer.conv1d.weight": "model_0.safetensors",
    "model.layers.7.mixer.conv1d.bias": "model_0.safetensors",
    "model.layers.7.mixer.dt_proj.bias": "model_0.safetensors",
    "model.layers.7.mixer.in_proj.weight": "model_0.safetensors",
    "model.layers.7.mixer.dt_proj.weight": "model_0.safetensors",
    "model.layers.7.mixer.out_proj.weight": "model_0.safetensors",
    "model.layers.7.mlp.gate_proj.weight": "model_0.safetensors",
    "model.layers.7.mlp.up_proj.weight": "model_0.safetensors",
    "model.layers.7.mlp.down_proj.weight": "model_0.safetensors",
    "model.layers.8.mixer.A_log": "model_0.safetensors",
    "model.layers.8.mixer.D": "model_0.safetensors",
    "model.layers.8.input_layernorm.weight": "model_0.safetensors",
    "model.layers.8.post_attention_layernorm.weight": "model_0.safetensors",
    "model.layers.8.mixer.dt_in_proj.weight": "model_0.safetensors",
    "model.layers.8.mixer.conv1d.weight": "model_0.safetensors",
    "model.layers.8.mixer.conv1d.bias": "model_0.safetensors",
    "model.layers.8.mixer.dt_proj.bias": "model_0.safetensors",
    "model.layers.8.mixer.in_proj.weight": "model_0.safetensors",
    "model.layers.8.mixer.dt_proj.weight": "model_0.safetensors",
    "model.layers.8.mixer.out_proj.weight": "model_0.safetensors",
    "model.layers.8.mlp.gate_proj.weight": "model_0.safetensors",
    "model.layers.8.mlp.up_proj.weight": "model_0.safetensors",
    "model.layers.8.mlp.down_proj.weight": "model_0.safetensors",
    "model.layers.9.input_layernorm.weight": "model_0.safetensors",
    "model.layers.9.post_attention_layernorm.weight": "model_0.safetensors",
    "model.layers.9.self_attn.q_proj.weight": "model_0.safetensors",
    "model.layers.9.self_attn.k_proj.weight": "model_0.safetensors",
    "model.layers.9.self_attn.v_proj.weight": "model_0.safetensors",
    "model.layers.9.self_attn.o_proj.weight": "model_0.safetensors",
    "model.layers.9.mlp.gate_proj.weight": "model_0.safetensors",
    "model.layers.9.mlp.up_proj.weight": "model_0.safetensors",
    "model.layers.9.mlp.down_proj.weight": "model_0.safetensors",
    "model.layers.10.input_layernorm.weight": "model_0.safetensors",
    "model.layers.10.post_attention_layernorm.weight": "model_0.safetensors",
    "model.layers.10.self_attn.q_proj.weight": "model_0.safetensors",
    "model.layers.10.self_attn.k_proj.weight": "model_0.safetensors",
    "model.layers.10.self_attn.v_proj.weight": "model_0.safetensors",
    "model.layers.10.self_attn.o_proj.weight": "model_0.safetensors",
    "model.layers.10.mlp.gate_proj.weight": "model_0.safetensors",
    "model.layers.10.mlp.up_proj.weight": "model_0.safetensors",
    "model.layers.10.mlp.down_proj.weight": "model_0.safetensors",
    "model.layers.11.input_layernorm.weight": "model_0.safetensors",
    "model.layers.11.post_attention_layernorm.weight": "model_0.safetensors",
    "model.layers.11.self_attn.q_proj.weight": "model_0.safetensors",
    "model.layers.11.self_attn.k_proj.weight": "model_0.safetensors",
    "model.layers.11.self_attn.v_proj.weight": "model_0.safetensors",
    "model.layers.11.self_attn.o_proj.weight": "model_0.safetensors",
    "model.layers.11.mlp.gate_proj.weight": "model_0.safetensors",
    "model.layers.11.mlp.up_proj.weight": "model_0.safetensors",
    "model.layers.11.mlp.down_proj.weight": "model_0.safetensors",
    "model.layers.12.input_layernorm.weight": "model_0.safetensors",
    "model.layers.12.post_attention_layernorm.weight": "model_0.safetensors",
    "model.layers.12.self_attn.q_proj.weight": "model_0.safetensors",
    "model.layers.12.self_attn.k_proj.weight": "model_0.safetensors",
    "model.layers.12.self_attn.v_proj.weight": "model_0.safetensors",
    "model.layers.12.self_attn.o_proj.weight": "model_0.safetensors",
    "model.layers.12.mlp.gate_proj.weight": "model_0.safetensors",
    "model.layers.12.mlp.up_proj.weight": "model_0.safetensors",
    "model.layers.12.mlp.down_proj.weight": "model_0.safetensors",
    "model.layers.13.input_layernorm.weight": "model_1.safetensors",
    "model.layers.13.post_attention_layernorm.weight": "model_1.safetensors",
    "model.layers.13.self_attn.q_proj.weight": "model_1.safetensors",
    "model.layers.13.self_attn.k_proj.weight": "model_1.safetensors",
    "model.layers.13.self_attn.v_proj.weight": "model_1.safetensors",
    "model.layers.13.self_attn.o_proj.weight": "model_1.safetensors",
    "model.layers.13.mlp.gate_proj.weight": "model_1.safetensors",
    "model.layers.13.mlp.up_proj.weight": "model_1.safetensors",
    "model.layers.13.mlp.down_proj.weight": "model_1.safetensors",
    "model.layers.14.input_layernorm.weight": "model_1.safetensors",
    "model.layers.14.post_attention_layernorm.weight": "model_1.safetensors",
    "model.layers.14.self_attn.q_proj.weight": "model_1.safetensors",
    "model.layers.14.self_attn.k_proj.weight": "model_1.safetensors",
    "model.layers.14.self_attn.v_proj.weight": "model_1.safetensors",
    "model.layers.14.self_attn.o_proj.weight": "model_1.safetensors",
    "model.layers.14.mlp.gate_proj.weight": "model_1.safetensors",
    "model.layers.14.mlp.up_proj.weight": "model_1.safetensors",
    "model.layers.14.mlp.down_proj.weight": "model_1.safetensors",
    "model.layers.15.input_layernorm.weight": "model_1.safetensors",
    "model.layers.15.post_attention_layernorm.weight": "model_1.safetensors",
    "model.layers.15.self_attn.q_proj.weight": "model_1.safetensors",
    "model.layers.15.self_attn.k_proj.weight": "model_1.safetensors",
    "model.layers.15.self_attn.v_proj.weight": "model_1.safetensors",
    "model.layers.15.self_attn.o_proj.weight": "model_1.safetensors",
    "model.layers.15.mlp.gate_proj.weight": "model_1.safetensors",
    "model.layers.15.mlp.up_proj.weight": "model_1.safetensors",
    "model.layers.15.mlp.down_proj.weight": "model_1.safetensors",
    "model.layers.16.mixer.A_log": "model_1.safetensors",
    "model.layers.16.mixer.D": "model_1.safetensors",
    "model.layers.16.input_layernorm.weight": "model_1.safetensors",
    "model.layers.16.post_attention_layernorm.weight": "model_1.safetensors",
    "model.layers.16.mixer.dt_in_proj.weight": "model_1.safetensors",
    "model.layers.16.mixer.conv1d.weight": "model_1.safetensors",
    "model.layers.16.mixer.conv1d.bias": "model_1.safetensors",
    "model.layers.16.mixer.dt_proj.bias": "model_1.safetensors",
    "model.layers.16.mixer.in_proj.weight": "model_1.safetensors",
    "model.layers.16.mixer.dt_proj.weight": "model_1.safetensors",
    "model.layers.16.mixer.out_proj.weight": "model_1.safetensors",
    "model.layers.16.mlp.gate_proj.weight": "model_1.safetensors",
    "model.layers.16.mlp.up_proj.weight": "model_1.safetensors",
    "model.layers.16.mlp.down_proj.weight": "model_1.safetensors",
    "model.layers.17.input_layernorm.weight": "model_1.safetensors",
    "model.layers.17.post_attention_layernorm.weight": "model_1.safetensors",
    "model.layers.17.self_attn.q_proj.weight": "model_1.safetensors",
    "model.layers.17.self_attn.k_proj.weight": "model_1.safetensors",
    "model.layers.17.self_attn.v_proj.weight": "model_1.safetensors",
    "model.layers.17.self_attn.o_proj.weight": "model_1.safetensors",
    "model.layers.17.mlp.gate_proj.weight": "model_1.safetensors",
    "model.layers.17.mlp.up_proj.weight": "model_1.safetensors",
    "model.layers.17.mlp.down_proj.weight": "model_1.safetensors",
    "model.layers.18.mixer.A_log": "model_1.safetensors",
    "model.layers.18.mixer.D": "model_1.safetensors",
    "model.layers.18.input_layernorm.weight": "model_1.safetensors",
    "model.layers.18.post_attention_layernorm.weight": "model_1.safetensors",
    "model.layers.18.mixer.dt_in_proj.weight": "model_1.safetensors",
    "model.layers.18.mixer.conv1d.weight": "model_1.safetensors",
    "model.layers.18.mixer.conv1d.bias": "model_1.safetensors",
    "model.layers.18.mixer.dt_proj.bias": "model_1.safetensors",
    "model.layers.18.mixer.in_proj.weight": "model_1.safetensors",
    "model.layers.18.mixer.dt_proj.weight": "model_1.safetensors",
    "model.layers.18.mixer.out_proj.weight": "model_1.safetensors",
    "model.layers.18.mlp.gate_proj.weight": "model_1.safetensors",
    "model.layers.18.mlp.up_proj.weight": "model_1.safetensors",
    "model.layers.18.mlp.down_proj.weight": "model_1.safetensors",
    "model.layers.19.mixer.A_log": "model_1.safetensors",
    "model.layers.19.mixer.D": "model_1.safetensors",
    "model.layers.19.input_layernorm.weight": "model_1.safetensors",
    "model.layers.19.post_attention_layernorm.weight": "model_1.safetensors",
    "model.layers.19.mixer.dt_in_proj.weight": "model_1.safetensors",
    "model.layers.19.mixer.conv1d.weight": "model_1.safetensors",
    "model.layers.19.mixer.conv1d.bias": "model_1.safetensors",
    "model.layers.19.mixer.dt_proj.bias": "model_1.safetensors",
    "model.layers.19.mixer.in_proj.weight": "model_1.safetensors",
    "model.layers.19.mixer.dt_proj.weight": "model_1.safetensors",
    "model.layers.19.mixer.out_proj.weight": "model_1.safetensors",
    "model.layers.19.mlp.gate_proj.weight": "model_1.safetensors",
    "model.layers.19.mlp.up_proj.weight": "model_1.safetensors",
    "model.layers.19.mlp.down_proj.weight": "model_1.safetensors",
    "model.layers.20.mixer.A_log": "model_1.safetensors",
    "model.layers.20.mixer.D": "model_1.safetensors",
    "model.layers.20.input_layernorm.weight": "model_1.safetensors",
    "model.layers.20.post_attention_layernorm.weight": "model_1.safetensors",
    "model.layers.20.mixer.dt_in_proj.weight": "model_1.safetensors",
    "model.layers.20.mixer.conv1d.weight": "model_1.safetensors",
    "model.layers.20.mixer.conv1d.bias": "model_1.safetensors",
    "model.layers.20.mixer.dt_proj.bias": "model_1.safetensors",
    "model.layers.20.mixer.in_proj.weight": "model_1.safetensors",
    "model.layers.20.mixer.dt_proj.weight": "model_1.safetensors",
    "model.layers.20.mixer.out_proj.weight": "model_1.safetensors",
    "model.layers.20.mlp.gate_proj.weight": "model_1.safetensors",
    "model.layers.20.mlp.up_proj.weight": "model_1.safetensors",
    "model.layers.20.mlp.down_proj.weight": "model_1.safetensors",
    "model.layers.21.mixer.A_log": "model_1.safetensors",
    "model.layers.21.mixer.D": "model_1.safetensors",
    "model.layers.21.input_layernorm.weight": "model_1.safetensors",
    "model.layers.21.post_attention_layernorm.weight": "model_1.safetensors",
    "model.layers.21.mixer.dt_in_proj.weight": "model_1.safetensors",
    "model.layers.21.mixer.conv1d.weight": "model_1.safetensors",
    "model.layers.21.mixer.conv1d.bias": "model_1.safetensors",
    "model.layers.21.mixer.dt_proj.bias": "model_1.safetensors",
    "model.layers.21.mixer.in_proj.weight": "model_1.safetensors",
    "model.layers.21.mixer.dt_proj.weight": "model_1.safetensors",
    "model.layers.21.mixer.out_proj.weight": "model_1.safetensors",
    "model.layers.21.mlp.gate_proj.weight": "model_1.safetensors",
    "model.layers.21.mlp.up_proj.weight": "model_1.safetensors",
    "model.layers.21.mlp.down_proj.weight": "model_1.safetensors",
    "model.layers.22.mixer.A_log": "model_1.safetensors",
    "model.layers.22.mixer.D": "model_1.safetensors",
    "model.layers.22.input_layernorm.weight": "model_1.safetensors",
    "model.layers.22.post_attention_layernorm.weight": "model_1.safetensors",
    "model.layers.22.mixer.dt_in_proj.weight": "model_1.safetensors",
    "model.layers.22.mixer.conv1d.weight": "model_1.safetensors",
    "model.layers.22.mixer.conv1d.bias": "model_1.safetensors",
    "model.layers.22.mixer.dt_proj.bias": "model_1.safetensors",
    "model.layers.22.mixer.in_proj.weight": "model_1.safetensors",
    "model.layers.22.mixer.dt_proj.weight": "model_1.safetensors",
    "model.layers.22.mixer.out_proj.weight": "model_1.safetensors",
    "model.layers.22.mlp.gate_proj.weight": "model_1.safetensors",
    "model.layers.22.mlp.up_proj.weight": "model_1.safetensors",
    "model.layers.22.mlp.down_proj.weight": "model_1.safetensors",
    "model.layers.23.mixer.A_log": "model_1.safetensors",
    "model.layers.23.mixer.D": "model_1.safetensors",
    "model.layers.23.input_layernorm.weight": "model_1.safetensors",
    "model.layers.23.post_attention_layernorm.weight": "model_1.safetensors",
    "model.layers.23.mixer.dt_in_proj.weight": "model_1.safetensors",
    "model.layers.23.mixer.conv1d.weight": "model_1.safetensors",
    "model.layers.23.mixer.conv1d.bias": "model_1.safetensors",
    "model.layers.23.mixer.dt_proj.bias": "model_1.safetensors",
    "model.layers.23.mixer.in_proj.weight": "model_1.safetensors",
    "model.layers.23.mixer.dt_proj.weight": "model_1.safetensors",
    "model.layers.23.mixer.out_proj.weight": "model_1.safetensors",
    "model.layers.23.mlp.gate_proj.weight": "model_1.safetensors",
    "model.layers.23.mlp.up_proj.weight": "model_1.safetensors",
    "model.layers.23.mlp.down_proj.weight": "model_1.safetensors",
    "model.layers.24.mixer.A_log": "model_1.safetensors",
    "model.layers.24.mixer.D": "model_1.safetensors",
    "model.layers.24.input_layernorm.weight": "model_1.safetensors",
    "model.layers.24.post_attention_layernorm.weight": "model_1.safetensors",
    "model.layers.24.mixer.dt_in_proj.weight": "model_1.safetensors",
|
| 573 |
+
"model.layers.24.mixer.conv1d.weight": "model_1.safetensors",
|
| 574 |
+
"model.layers.24.mixer.conv1d.bias": "model_1.safetensors",
|
| 575 |
+
"model.layers.24.mixer.dt_proj.bias": "model_1.safetensors",
|
| 576 |
+
"model.layers.24.mixer.in_proj.weight": "model_1.safetensors",
|
| 577 |
+
"model.layers.24.mixer.dt_proj.weight": "model_1.safetensors",
|
| 578 |
+
"model.layers.24.mixer.out_proj.weight": "model_1.safetensors",
|
| 579 |
+
"model.layers.24.mlp.gate_proj.weight": "model_1.safetensors",
|
| 580 |
+
"model.layers.24.mlp.up_proj.weight": "model_1.safetensors",
|
| 581 |
+
"model.layers.24.mlp.down_proj.weight": "model_1.safetensors",
|
| 582 |
+
"model.layers.25.mixer.A_log": "model_1.safetensors",
|
| 583 |
+
"model.layers.25.mixer.D": "model_1.safetensors",
|
| 584 |
+
"model.layers.25.input_layernorm.weight": "model_1.safetensors",
|
| 585 |
+
"model.layers.25.post_attention_layernorm.weight": "model_1.safetensors",
|
| 586 |
+
"model.layers.25.mixer.dt_in_proj.weight": "model_1.safetensors",
|
| 587 |
+
"model.layers.25.mixer.conv1d.weight": "model_1.safetensors",
|
| 588 |
+
"model.layers.25.mixer.conv1d.bias": "model_1.safetensors",
|
| 589 |
+
"model.layers.25.mixer.dt_proj.bias": "model_1.safetensors",
|
| 590 |
+
"model.layers.25.mixer.in_proj.weight": "model_1.safetensors",
|
| 591 |
+
"model.layers.25.mixer.dt_proj.weight": "model_1.safetensors",
|
| 592 |
+
"model.layers.25.mixer.out_proj.weight": "model_1.safetensors",
|
| 593 |
+
"model.layers.25.mlp.gate_proj.weight": "model_1.safetensors",
|
| 594 |
+
"model.layers.25.mlp.up_proj.weight": "model_1.safetensors",
|
| 595 |
+
"model.layers.25.mlp.down_proj.weight": "model_1.safetensors",
|
| 596 |
+
"model.layers.26.mixer.A_log": "model_1.safetensors",
|
| 597 |
+
"model.layers.26.mixer.D": "model_1.safetensors",
|
| 598 |
+
"model.layers.26.input_layernorm.weight": "model_1.safetensors",
|
| 599 |
+
"model.layers.26.post_attention_layernorm.weight": "model_1.safetensors",
|
| 600 |
+
"model.layers.26.mixer.dt_in_proj.weight": "model_1.safetensors",
|
| 601 |
+
"model.layers.26.mixer.conv1d.weight": "model_1.safetensors",
|
| 602 |
+
"model.layers.26.mixer.conv1d.bias": "model_1.safetensors",
|
| 603 |
+
"model.layers.26.mixer.dt_proj.bias": "model_1.safetensors",
|
| 604 |
+
"model.layers.26.mixer.in_proj.weight": "model_1.safetensors",
|
| 605 |
+
"model.layers.26.mixer.dt_proj.weight": "model_1.safetensors",
|
| 606 |
+
"model.layers.26.mixer.out_proj.weight": "model_1.safetensors",
|
| 607 |
+
"model.layers.26.mlp.gate_proj.weight": "model_1.safetensors",
|
| 608 |
+
"model.layers.26.mlp.up_proj.weight": "model_1.safetensors",
|
| 609 |
+
"model.layers.26.mlp.down_proj.weight": "model_1.safetensors",
|
| 610 |
+
"model.layers.27.mixer.A_log": "model_1.safetensors",
|
| 611 |
+
"model.layers.27.mixer.D": "model_1.safetensors",
|
| 612 |
+
"model.layers.27.input_layernorm.weight": "model_1.safetensors",
|
| 613 |
+
"model.layers.27.post_attention_layernorm.weight": "model_1.safetensors",
|
| 614 |
+
"model.layers.27.mixer.dt_in_proj.weight": "model_1.safetensors",
|
| 615 |
+
"model.layers.27.mixer.conv1d.weight": "model_1.safetensors",
|
| 616 |
+
"model.layers.27.mixer.conv1d.bias": "model_1.safetensors",
|
| 617 |
+
"model.layers.27.mixer.dt_proj.bias": "model_1.safetensors",
|
| 618 |
+
"model.layers.27.mixer.in_proj.weight": "model_1.safetensors",
|
| 619 |
+
"model.layers.27.mixer.dt_proj.weight": "model_1.safetensors",
|
| 620 |
+
"model.layers.27.mixer.out_proj.weight": "model_1.safetensors",
|
| 621 |
+
"model.layers.27.mlp.gate_proj.weight": "model_1.safetensors",
|
| 622 |
+
"model.layers.27.mlp.up_proj.weight": "model_1.safetensors",
|
| 623 |
+
"model.layers.27.mlp.down_proj.weight": "model_1.safetensors",
|
| 624 |
+
"model.layers.28.input_layernorm.weight": "model_2.safetensors",
|
| 625 |
+
"model.layers.28.post_attention_layernorm.weight": "model_2.safetensors",
|
| 626 |
+
"model.layers.28.self_attn.q_proj.weight": "model_2.safetensors",
|
| 627 |
+
"model.layers.28.self_attn.k_proj.weight": "model_2.safetensors",
|
| 628 |
+
"model.layers.28.self_attn.v_proj.weight": "model_2.safetensors",
|
| 629 |
+
"model.layers.28.self_attn.o_proj.weight": "model_2.safetensors",
|
| 630 |
+
"model.layers.28.mlp.gate_proj.weight": "model_2.safetensors",
|
| 631 |
+
"model.layers.28.mlp.up_proj.weight": "model_2.safetensors",
|
| 632 |
+
"model.layers.28.mlp.down_proj.weight": "model_2.safetensors",
|
| 633 |
+
"model.layers.29.input_layernorm.weight": "model_2.safetensors",
|
| 634 |
+
"model.layers.29.post_attention_layernorm.weight": "model_2.safetensors",
|
| 635 |
+
"model.layers.29.self_attn.q_proj.weight": "model_2.safetensors",
|
| 636 |
+
"model.layers.29.self_attn.k_proj.weight": "model_2.safetensors",
|
| 637 |
+
"model.layers.29.self_attn.v_proj.weight": "model_2.safetensors",
|
| 638 |
+
"model.layers.29.self_attn.o_proj.weight": "model_2.safetensors",
|
| 639 |
+
"model.layers.29.mlp.gate_proj.weight": "model_2.safetensors",
|
| 640 |
+
"model.layers.29.mlp.up_proj.weight": "model_2.safetensors",
|
| 641 |
+
"model.layers.29.mlp.down_proj.weight": "model_2.safetensors",
|
| 642 |
+
"model.layers.30.mixer.A_log": "model_2.safetensors",
|
| 643 |
+
"model.layers.30.mixer.D": "model_2.safetensors",
|
| 644 |
+
"model.layers.30.input_layernorm.weight": "model_2.safetensors",
|
| 645 |
+
"model.layers.30.post_attention_layernorm.weight": "model_2.safetensors",
|
| 646 |
+
"model.layers.30.mixer.dt_in_proj.weight": "model_2.safetensors",
|
| 647 |
+
"model.layers.30.mixer.conv1d.weight": "model_2.safetensors",
|
| 648 |
+
"model.layers.30.mixer.conv1d.bias": "model_2.safetensors",
|
| 649 |
+
"model.layers.30.mixer.dt_proj.bias": "model_2.safetensors",
|
| 650 |
+
"model.layers.30.mixer.in_proj.weight": "model_2.safetensors",
|
| 651 |
+
"model.layers.30.mixer.dt_proj.weight": "model_2.safetensors",
|
| 652 |
+
"model.layers.30.mixer.out_proj.weight": "model_2.safetensors",
|
| 653 |
+
"model.layers.30.mlp.gate_proj.weight": "model_2.safetensors",
|
| 654 |
+
"model.layers.30.mlp.up_proj.weight": "model_2.safetensors",
|
| 655 |
+
"model.layers.30.mlp.down_proj.weight": "model_2.safetensors",
|
| 656 |
+
"model.layers.31.mixer.A_log": "model_2.safetensors",
|
| 657 |
+
"model.layers.31.mixer.D": "model_2.safetensors",
|
| 658 |
+
"model.layers.31.input_layernorm.weight": "model_2.safetensors",
|
| 659 |
+
"model.layers.31.post_attention_layernorm.weight": "model_2.safetensors",
|
| 660 |
+
"model.layers.31.mixer.dt_in_proj.weight": "model_2.safetensors",
|
| 661 |
+
"model.layers.31.mixer.conv1d.weight": "model_2.safetensors",
|
| 662 |
+
"model.layers.31.mixer.conv1d.bias": "model_2.safetensors",
|
| 663 |
+
"model.layers.31.mixer.dt_proj.bias": "model_2.safetensors",
|
| 664 |
+
"model.layers.31.mixer.in_proj.weight": "model_2.safetensors",
|
| 665 |
+
"model.layers.31.mixer.dt_proj.weight": "model_2.safetensors",
|
| 666 |
+
"model.layers.31.mixer.out_proj.weight": "model_2.safetensors",
|
| 667 |
+
"model.layers.31.mlp.gate_proj.weight": "model_2.safetensors",
|
| 668 |
+
"model.layers.31.mlp.up_proj.weight": "model_2.safetensors",
|
| 669 |
+
"model.layers.31.mlp.down_proj.weight": "model_2.safetensors",
|
| 670 |
+
"model.layers.32.mixer.A_log": "model_2.safetensors",
|
| 671 |
+
"model.layers.32.mixer.D": "model_2.safetensors",
|
| 672 |
+
"model.layers.32.input_layernorm.weight": "model_2.safetensors",
|
| 673 |
+
"model.layers.32.post_attention_layernorm.weight": "model_2.safetensors",
|
| 674 |
+
"model.layers.32.mixer.dt_in_proj.weight": "model_2.safetensors",
|
| 675 |
+
"model.layers.32.mixer.conv1d.weight": "model_2.safetensors",
|
| 676 |
+
"model.layers.32.mixer.conv1d.bias": "model_2.safetensors",
|
| 677 |
+
"model.layers.32.mixer.dt_proj.bias": "model_2.safetensors",
|
| 678 |
+
"model.layers.32.mixer.in_proj.weight": "model_2.safetensors",
|
| 679 |
+
"model.layers.32.mixer.dt_proj.weight": "model_2.safetensors",
|
| 680 |
+
"model.layers.32.mixer.out_proj.weight": "model_2.safetensors",
|
| 681 |
+
"model.layers.32.mlp.gate_proj.weight": "model_2.safetensors",
|
| 682 |
+
"model.layers.32.mlp.up_proj.weight": "model_2.safetensors",
|
| 683 |
+
"model.layers.32.mlp.down_proj.weight": "model_2.safetensors",
|
| 684 |
+
"model.layers.33.mixer.A_log": "model_2.safetensors",
|
| 685 |
+
"model.layers.33.mixer.D": "model_2.safetensors",
|
| 686 |
+
"model.layers.33.input_layernorm.weight": "model_2.safetensors",
|
| 687 |
+
"model.layers.33.post_attention_layernorm.weight": "model_2.safetensors",
|
| 688 |
+
"model.layers.33.mixer.dt_in_proj.weight": "model_2.safetensors",
|
| 689 |
+
"model.layers.33.mixer.conv1d.weight": "model_2.safetensors",
|
| 690 |
+
"model.layers.33.mixer.conv1d.bias": "model_2.safetensors",
|
| 691 |
+
"model.layers.33.mixer.dt_proj.bias": "model_2.safetensors",
|
| 692 |
+
"model.layers.33.mixer.in_proj.weight": "model_2.safetensors",
|
| 693 |
+
"model.layers.33.mixer.dt_proj.weight": "model_2.safetensors",
|
| 694 |
+
"model.layers.33.mixer.out_proj.weight": "model_2.safetensors",
|
| 695 |
+
"model.layers.33.mlp.gate_proj.weight": "model_2.safetensors",
|
| 696 |
+
"model.layers.33.mlp.up_proj.weight": "model_2.safetensors",
|
| 697 |
+
"model.layers.33.mlp.down_proj.weight": "model_2.safetensors",
|
| 698 |
+
"model.layers.34.input_layernorm.weight": "model_2.safetensors",
|
| 699 |
+
"model.layers.34.post_attention_layernorm.weight": "model_2.safetensors",
|
| 700 |
+
"model.layers.34.self_attn.q_proj.weight": "model_2.safetensors",
|
| 701 |
+
"model.layers.34.self_attn.k_proj.weight": "model_2.safetensors",
|
| 702 |
+
"model.layers.34.self_attn.v_proj.weight": "model_2.safetensors",
|
| 703 |
+
"model.layers.34.self_attn.o_proj.weight": "model_2.safetensors",
|
| 704 |
+
"model.layers.34.mlp.gate_proj.weight": "model_2.safetensors",
|
| 705 |
+
"model.layers.34.mlp.up_proj.weight": "model_2.safetensors",
|
| 706 |
+
"model.layers.34.mlp.down_proj.weight": "model_2.safetensors",
|
| 707 |
+
"model.layers.35.mixer.A_log": "model_2.safetensors",
|
| 708 |
+
"model.layers.35.mixer.D": "model_2.safetensors",
|
| 709 |
+
"model.layers.35.input_layernorm.weight": "model_2.safetensors",
|
| 710 |
+
"model.layers.35.post_attention_layernorm.weight": "model_2.safetensors",
|
| 711 |
+
"model.layers.35.mixer.dt_in_proj.weight": "model_2.safetensors",
|
| 712 |
+
"model.layers.35.mixer.conv1d.weight": "model_2.safetensors",
|
| 713 |
+
"model.layers.35.mixer.conv1d.bias": "model_2.safetensors",
|
| 714 |
+
"model.layers.35.mixer.dt_proj.bias": "model_2.safetensors",
|
| 715 |
+
"model.layers.35.mixer.in_proj.weight": "model_2.safetensors",
|
| 716 |
+
"model.layers.35.mixer.dt_proj.weight": "model_2.safetensors",
|
| 717 |
+
"model.layers.35.mixer.out_proj.weight": "model_2.safetensors",
|
| 718 |
+
"model.layers.35.mlp.gate_proj.weight": "model_2.safetensors",
|
| 719 |
+
"model.layers.35.mlp.up_proj.weight": "model_2.safetensors",
|
| 720 |
+
"model.layers.35.mlp.down_proj.weight": "model_2.safetensors",
|
| 721 |
+
"model.layers.36.mixer.A_log": "model_2.safetensors",
|
| 722 |
+
"model.layers.36.mixer.D": "model_2.safetensors",
|
| 723 |
+
"model.layers.36.input_layernorm.weight": "model_2.safetensors",
|
| 724 |
+
"model.layers.36.post_attention_layernorm.weight": "model_2.safetensors",
|
| 725 |
+
"model.layers.36.mixer.dt_in_proj.weight": "model_2.safetensors",
|
| 726 |
+
"model.layers.36.mixer.conv1d.weight": "model_2.safetensors",
|
| 727 |
+
"model.layers.36.mixer.conv1d.bias": "model_2.safetensors",
|
| 728 |
+
"model.layers.36.mixer.dt_proj.bias": "model_2.safetensors",
|
| 729 |
+
"model.layers.36.mixer.in_proj.weight": "model_2.safetensors",
|
| 730 |
+
"model.layers.36.mixer.dt_proj.weight": "model_2.safetensors",
|
| 731 |
+
"model.layers.36.mixer.out_proj.weight": "model_2.safetensors",
|
| 732 |
+
"model.layers.36.mlp.gate_proj.weight": "model_2.safetensors",
|
| 733 |
+
"model.layers.36.mlp.up_proj.weight": "model_2.safetensors",
|
| 734 |
+
"model.layers.36.mlp.down_proj.weight": "model_2.safetensors",
|
| 735 |
+
"model.layers.37.mixer.A_log": "model_2.safetensors",
|
| 736 |
+
"model.layers.37.mixer.D": "model_2.safetensors",
|
| 737 |
+
"model.layers.37.input_layernorm.weight": "model_2.safetensors",
|
| 738 |
+
"model.layers.37.post_attention_layernorm.weight": "model_2.safetensors",
|
| 739 |
+
"model.layers.37.mixer.dt_in_proj.weight": "model_2.safetensors",
|
| 740 |
+
"model.layers.37.mixer.conv1d.weight": "model_2.safetensors",
|
| 741 |
+
"model.layers.37.mixer.conv1d.bias": "model_2.safetensors",
|
| 742 |
+
"model.layers.37.mixer.dt_proj.bias": "model_2.safetensors",
|
| 743 |
+
"model.layers.37.mixer.in_proj.weight": "model_2.safetensors",
|
| 744 |
+
"model.layers.37.mixer.dt_proj.weight": "model_2.safetensors",
|
| 745 |
+
"model.layers.37.mixer.out_proj.weight": "model_2.safetensors",
|
| 746 |
+
"model.layers.37.mlp.gate_proj.weight": "model_2.safetensors",
|
| 747 |
+
"model.layers.37.mlp.up_proj.weight": "model_2.safetensors",
|
| 748 |
+
"model.layers.37.mlp.down_proj.weight": "model_2.safetensors",
|
| 749 |
+
"model.layers.38.mixer.A_log": "model_2.safetensors",
|
| 750 |
+
"model.layers.38.mixer.D": "model_2.safetensors",
|
| 751 |
+
"model.layers.38.input_layernorm.weight": "model_2.safetensors",
|
| 752 |
+
"model.layers.38.post_attention_layernorm.weight": "model_2.safetensors",
|
| 753 |
+
"model.layers.38.mixer.dt_in_proj.weight": "model_2.safetensors",
|
| 754 |
+
"model.layers.38.mixer.conv1d.weight": "model_2.safetensors",
|
| 755 |
+
"model.layers.38.mixer.conv1d.bias": "model_2.safetensors",
|
| 756 |
+
"model.layers.38.mixer.dt_proj.bias": "model_2.safetensors",
|
| 757 |
+
"model.layers.38.mixer.in_proj.weight": "model_2.safetensors",
|
| 758 |
+
"model.layers.38.mixer.dt_proj.weight": "model_2.safetensors",
|
| 759 |
+
"model.layers.38.mixer.out_proj.weight": "model_2.safetensors",
|
| 760 |
+
"model.layers.38.mlp.gate_proj.weight": "model_2.safetensors",
|
| 761 |
+
"model.layers.38.mlp.up_proj.weight": "model_2.safetensors",
|
| 762 |
+
"model.layers.38.mlp.down_proj.weight": "model_2.safetensors",
|
| 763 |
+
"model.layers.39.mixer.A_log": "model_2.safetensors",
|
| 764 |
+
"model.layers.39.mixer.D": "model_2.safetensors",
|
| 765 |
+
"model.layers.39.input_layernorm.weight": "model_2.safetensors",
|
| 766 |
+
"model.layers.39.post_attention_layernorm.weight": "model_2.safetensors",
|
| 767 |
+
"model.layers.39.mixer.dt_in_proj.weight": "model_2.safetensors",
|
| 768 |
+
"model.layers.39.mixer.conv1d.weight": "model_2.safetensors",
|
| 769 |
+
"model.layers.39.mixer.conv1d.bias": "model_2.safetensors",
|
| 770 |
+
"model.layers.39.mixer.dt_proj.bias": "model_2.safetensors",
|
| 771 |
+
"model.layers.39.mixer.in_proj.weight": "model_2.safetensors",
|
| 772 |
+
"model.layers.39.mixer.dt_proj.weight": "model_2.safetensors",
|
| 773 |
+
"model.layers.39.mixer.out_proj.weight": "model_2.safetensors",
|
| 774 |
+
"model.layers.39.mlp.gate_proj.weight": "model_2.safetensors",
|
| 775 |
+
"model.layers.39.mlp.up_proj.weight": "model_2.safetensors",
|
| 776 |
+
"model.layers.39.mlp.down_proj.weight": "model_2.safetensors",
|
| 777 |
+
"model.layers.40.mixer.A_log": "model_2.safetensors",
|
| 778 |
+
"model.layers.40.mixer.D": "model_2.safetensors",
|
| 779 |
+
"model.layers.40.input_layernorm.weight": "model_2.safetensors",
|
| 780 |
+
"model.layers.40.post_attention_layernorm.weight": "model_2.safetensors",
|
| 781 |
+
"model.layers.40.mixer.dt_in_proj.weight": "model_2.safetensors",
|
| 782 |
+
"model.layers.40.mixer.conv1d.weight": "model_2.safetensors",
|
| 783 |
+
"model.layers.40.mixer.conv1d.bias": "model_2.safetensors",
|
| 784 |
+
"model.layers.40.mixer.dt_proj.bias": "model_2.safetensors",
|
| 785 |
+
"model.layers.40.mixer.in_proj.weight": "model_2.safetensors",
|
| 786 |
+
"model.layers.40.mixer.dt_proj.weight": "model_2.safetensors",
|
| 787 |
+
"model.layers.40.mixer.out_proj.weight": "model_2.safetensors",
|
| 788 |
+
"model.layers.40.mlp.gate_proj.weight": "model_2.safetensors",
|
| 789 |
+
"model.layers.40.mlp.up_proj.weight": "model_2.safetensors",
|
| 790 |
+
"model.layers.40.mlp.down_proj.weight": "model_2.safetensors",
|
| 791 |
+
"model.layers.41.mixer.A_log": "model_2.safetensors",
|
| 792 |
+
"model.layers.41.mixer.D": "model_2.safetensors",
|
| 793 |
+
"model.layers.41.input_layernorm.weight": "model_2.safetensors",
|
| 794 |
+
"model.layers.41.post_attention_layernorm.weight": "model_2.safetensors",
|
| 795 |
+
"model.layers.41.mixer.dt_in_proj.weight": "model_2.safetensors",
|
| 796 |
+
"model.layers.41.mixer.conv1d.weight": "model_2.safetensors",
|
| 797 |
+
"model.layers.41.mixer.conv1d.bias": "model_2.safetensors",
|
| 798 |
+
"model.layers.41.mixer.dt_proj.bias": "model_2.safetensors",
|
| 799 |
+
"model.layers.41.mixer.in_proj.weight": "model_2.safetensors",
|
| 800 |
+
"model.layers.41.mixer.dt_proj.weight": "model_2.safetensors",
|
| 801 |
+
"model.layers.41.mixer.out_proj.weight": "model_2.safetensors",
|
| 802 |
+
"model.layers.41.mlp.gate_proj.weight": "model_2.safetensors",
|
| 803 |
+
"model.layers.41.mlp.up_proj.weight": "model_2.safetensors",
|
| 804 |
+
"model.layers.41.mlp.down_proj.weight": "model_2.safetensors",
|
| 805 |
+
"model.layers.42.mixer.A_log": "model_2.safetensors",
|
| 806 |
+
"model.layers.42.mixer.D": "model_2.safetensors",
|
| 807 |
+
"model.layers.42.input_layernorm.weight": "model_2.safetensors",
|
| 808 |
+
"model.layers.42.post_attention_layernorm.weight": "model_2.safetensors",
|
| 809 |
+
"model.layers.42.mixer.dt_in_proj.weight": "model_2.safetensors",
|
| 810 |
+
"model.layers.42.mixer.conv1d.weight": "model_2.safetensors",
|
| 811 |
+
"model.layers.42.mixer.conv1d.bias": "model_2.safetensors",
|
| 812 |
+
"model.layers.42.mixer.dt_proj.bias": "model_2.safetensors",
|
| 813 |
+
"model.layers.42.mixer.in_proj.weight": "model_2.safetensors",
|
| 814 |
+
"model.layers.42.mixer.dt_proj.weight": "model_2.safetensors",
|
| 815 |
+
"model.layers.42.mixer.out_proj.weight": "model_2.safetensors",
|
| 816 |
+
"model.layers.42.mlp.gate_proj.weight": "model_2.safetensors",
|
| 817 |
+
"model.layers.42.mlp.up_proj.weight": "model_2.safetensors",
|
| 818 |
+
"model.layers.42.mlp.down_proj.weight": "model_3.safetensors",
|
| 819 |
+
"model.layers.43.mixer.A_log": "model_3.safetensors",
|
| 820 |
+
"model.layers.43.mixer.D": "model_3.safetensors",
|
| 821 |
+
"model.layers.43.input_layernorm.weight": "model_3.safetensors",
|
| 822 |
+
"model.layers.43.post_attention_layernorm.weight": "model_3.safetensors",
|
| 823 |
+
"model.layers.43.mixer.dt_in_proj.weight": "model_3.safetensors",
|
| 824 |
+
"model.layers.43.mixer.conv1d.weight": "model_3.safetensors",
|
| 825 |
+
"model.layers.43.mixer.conv1d.bias": "model_3.safetensors",
|
| 826 |
+
"model.layers.43.mixer.dt_proj.bias": "model_3.safetensors",
|
| 827 |
+
"model.layers.43.mixer.in_proj.weight": "model_3.safetensors",
|
| 828 |
+
"model.layers.43.mixer.dt_proj.weight": "model_3.safetensors",
|
| 829 |
+
"model.layers.43.mixer.out_proj.weight": "model_3.safetensors",
|
| 830 |
+
"model.layers.43.mlp.gate_proj.weight": "model_3.safetensors",
|
| 831 |
+
"model.layers.43.mlp.up_proj.weight": "model_3.safetensors",
|
| 832 |
+
"model.layers.43.mlp.down_proj.weight": "model_3.safetensors",
|
| 833 |
+
"model.layers.44.mixer.A_log": "model_3.safetensors",
|
| 834 |
+
"model.layers.44.mixer.D": "model_3.safetensors",
|
| 835 |
+
"model.layers.44.input_layernorm.weight": "model_3.safetensors",
|
| 836 |
+
"model.layers.44.post_attention_layernorm.weight": "model_3.safetensors",
|
| 837 |
+
"model.layers.44.mixer.dt_in_proj.weight": "model_3.safetensors",
|
| 838 |
+
"model.layers.44.mixer.conv1d.weight": "model_3.safetensors",
|
| 839 |
+
"model.layers.44.mixer.conv1d.bias": "model_3.safetensors",
|
| 840 |
+
"model.layers.44.mixer.dt_proj.bias": "model_3.safetensors",
|
| 841 |
+
"model.layers.44.mixer.in_proj.weight": "model_3.safetensors",
|
| 842 |
+
"model.layers.44.mixer.dt_proj.weight": "model_3.safetensors",
|
| 843 |
+
"model.layers.44.mixer.out_proj.weight": "model_3.safetensors",
|
| 844 |
+
"model.layers.44.mlp.gate_proj.weight": "model_3.safetensors",
|
| 845 |
+
"model.layers.44.mlp.up_proj.weight": "model_3.safetensors",
|
| 846 |
+
"model.layers.44.mlp.down_proj.weight": "model_3.safetensors",
|
| 847 |
+
"model.layers.45.mixer.A_log": "model_3.safetensors",
|
| 848 |
+
"model.layers.45.mixer.D": "model_3.safetensors",
|
| 849 |
+
"model.layers.45.input_layernorm.weight": "model_3.safetensors",
|
| 850 |
+
"model.layers.45.post_attention_layernorm.weight": "model_3.safetensors",
|
| 851 |
+
"model.layers.45.mixer.dt_in_proj.weight": "model_3.safetensors",
|
| 852 |
+
"model.layers.45.mixer.conv1d.weight": "model_3.safetensors",
|
| 853 |
+
"model.layers.45.mixer.conv1d.bias": "model_3.safetensors",
|
| 854 |
+
"model.layers.45.mixer.dt_proj.bias": "model_3.safetensors",
|
| 855 |
+
"model.layers.45.mixer.in_proj.weight": "model_3.safetensors",
|
| 856 |
+
"model.layers.45.mixer.dt_proj.weight": "model_3.safetensors",
|
| 857 |
+
"model.layers.45.mixer.out_proj.weight": "model_3.safetensors",
|
| 858 |
+
"model.layers.45.mlp.gate_proj.weight": "model_3.safetensors",
|
| 859 |
+
"model.layers.45.mlp.up_proj.weight": "model_3.safetensors",
|
| 860 |
+
"model.layers.45.mlp.down_proj.weight": "model_3.safetensors",
|
| 861 |
+
"model.layers.46.mixer.A_log": "model_3.safetensors",
|
| 862 |
+
"model.layers.46.mixer.D": "model_3.safetensors",
|
| 863 |
+
"model.layers.46.input_layernorm.weight": "model_3.safetensors",
|
| 864 |
+
"model.layers.46.post_attention_layernorm.weight": "model_3.safetensors",
|
| 865 |
+
"model.layers.46.mixer.dt_in_proj.weight": "model_3.safetensors",
|
| 866 |
+
"model.layers.46.mixer.conv1d.weight": "model_3.safetensors",
|
| 867 |
+
"model.layers.46.mixer.conv1d.bias": "model_3.safetensors",
|
| 868 |
+
"model.layers.46.mixer.dt_proj.bias": "model_3.safetensors",
|
| 869 |
+
"model.layers.46.mixer.in_proj.weight": "model_3.safetensors",
|
| 870 |
+
"model.layers.46.mixer.dt_proj.weight": "model_3.safetensors",
|
| 871 |
+
"model.layers.46.mixer.out_proj.weight": "model_3.safetensors",
|
| 872 |
+
"model.layers.46.mlp.gate_proj.weight": "model_3.safetensors",
|
| 873 |
+
"model.layers.46.mlp.up_proj.weight": "model_3.safetensors",
|
| 874 |
+
"model.layers.46.mlp.down_proj.weight": "model_3.safetensors",
|
| 875 |
+
"model.layers.47.mixer.A_log": "model_3.safetensors",
|
| 876 |
+
"model.layers.47.mixer.D": "model_3.safetensors",
|
| 877 |
+
"model.layers.47.input_layernorm.weight": "model_3.safetensors",
|
| 878 |
+
"model.layers.47.post_attention_layernorm.weight": "model_3.safetensors",
|
| 879 |
+
"model.layers.47.mixer.dt_in_proj.weight": "model_3.safetensors",
|
| 880 |
+
"model.layers.47.mixer.conv1d.weight": "model_3.safetensors",
|
| 881 |
+
"model.layers.47.mixer.conv1d.bias": "model_3.safetensors",
|
| 882 |
+
"model.layers.47.mixer.dt_proj.bias": "model_3.safetensors",
|
| 883 |
+
"model.layers.47.mixer.in_proj.weight": "model_3.safetensors",
|
| 884 |
+
"model.layers.47.mixer.dt_proj.weight": "model_3.safetensors",
|
| 885 |
+
"model.layers.47.mixer.out_proj.weight": "model_3.safetensors",
|
| 886 |
+
"model.layers.47.mlp.gate_proj.weight": "model_3.safetensors",
|
| 887 |
+
"model.layers.47.mlp.up_proj.weight": "model_3.safetensors",
|
| 888 |
+
"model.layers.47.mlp.down_proj.weight": "model_3.safetensors",
|
| 889 |
+
"model.layers.48.input_layernorm.weight": "model_3.safetensors",
|
| 890 |
+
"model.layers.48.post_attention_layernorm.weight": "model_3.safetensors",
|
| 891 |
+
"model.layers.48.self_attn.q_proj.weight": "model_3.safetensors",
|
| 892 |
+
"model.layers.48.self_attn.k_proj.weight": "model_3.safetensors",
|
| 893 |
+
"model.layers.48.self_attn.v_proj.weight": "model_3.safetensors",
|
| 894 |
+
"model.layers.48.self_attn.o_proj.weight": "model_3.safetensors",
|
| 895 |
+
"model.layers.48.mlp.gate_proj.weight": "model_3.safetensors",
|
| 896 |
+
"model.layers.48.mlp.up_proj.weight": "model_3.safetensors",
|
| 897 |
+
"model.layers.48.mlp.down_proj.weight": "model_3.safetensors",
|
| 898 |
+
"model.layers.49.mixer.A_log": "model_3.safetensors",
|
| 899 |
+
"model.layers.49.mixer.D": "model_3.safetensors",
|
| 900 |
+
"model.layers.49.input_layernorm.weight": "model_3.safetensors",
|
| 901 |
+
"model.layers.49.post_attention_layernorm.weight": "model_3.safetensors",
|
| 902 |
+
"model.layers.49.mixer.dt_in_proj.weight": "model_3.safetensors",
|
| 903 |
+
"model.layers.49.mixer.conv1d.weight": "model_3.safetensors",
|
| 904 |
+
"model.layers.49.mixer.conv1d.bias": "model_3.safetensors",
|
| 905 |
+
"model.layers.49.mixer.dt_proj.bias": "model_3.safetensors",
|
| 906 |
+
"model.layers.49.mixer.in_proj.weight": "model_3.safetensors",
|
| 907 |
+
"model.layers.49.mixer.dt_proj.weight": "model_3.safetensors",
|
| 908 |
+
"model.layers.49.mixer.out_proj.weight": "model_3.safetensors",
|
| 909 |
+
"model.layers.49.mlp.gate_proj.weight": "model_3.safetensors",
|
| 910 |
+
"model.layers.49.mlp.up_proj.weight": "model_3.safetensors",
|
| 911 |
+
"model.layers.49.mlp.down_proj.weight": "model_3.safetensors",
|
| 912 |
+
"model.norm.weight": "model_3.safetensors",
|
| 913 |
+
"lm_head.weight": "model_3.safetensors"
|
| 914 |
+
}
|
| 915 |
+
}
|
model_0.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:54ab7d8a72b9ed8dddb280c6b894a0a7c6404a6545cf77821e075708353ec122
|
| 3 |
+
size 17341952504
|
model_1.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3cf3c4655d0a8f554e50c3c8001f92f681edb1a68a18eae6f4c55857756ed141
|
| 3 |
+
size 17415079472
|
model_2.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:03913e3b49ba8415e659d23f28e30f55a973ff5b669ef8e323226331e7f98e84
|
| 3 |
+
size 17217537984
|
model_3.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e5288d374646a02d5bec8810577e6955c90e810fd958bcd1c835d17a1d45e755
|
| 3 |
+
size 11188268104
|
modeling_apriel_h.py
ADDED
|
@@ -0,0 +1,908 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import math
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Any, Optional, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
| 9 |
+
from .configuration_apriel_h import AprielHConfig
|
| 10 |
+
from einops import rearrange, repeat
|
| 11 |
+
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
|
| 12 |
+
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
|
| 13 |
+
from torch import nn
|
| 14 |
+
from transformers import GenerationMixin
|
| 15 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 16 |
+
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
| 17 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
| 18 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 19 |
+
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralMLP, MistralModel, MistralRMSNorm
|
| 20 |
+
from transformers.processing_utils import Unpack
|
| 21 |
+
from transformers.utils import LossKwargs, can_return_tuple, logging
|
| 22 |
+
from transformers.utils.generic import ModelOutput
|
| 23 |
+
|
| 24 |
+
logger = logging.get_logger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 31 |
+
"""
|
| 32 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 33 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 34 |
+
"""
|
| 35 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 36 |
+
if n_rep == 1:
|
| 37 |
+
return hidden_states
|
| 38 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 39 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py
|
| 43 |
+
class HybridMambaAttentionDynamicCache(DynamicCache):
|
| 44 |
+
"""
|
| 45 |
+
A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
|
| 46 |
+
(which has a constant shape regardless of seq_len).
|
| 47 |
+
This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states`
|
| 48 |
+
and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor
|
| 49 |
+
For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`,
|
| 50 |
+
while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors).
|
| 51 |
+
For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors),
|
| 52 |
+
while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`,
|
| 53 |
+
and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`.
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
def __init__(self, config: AprielHConfig, batch_size, dtype=torch.float16, device=None):
|
| 57 |
+
super().__init__()
|
| 58 |
+
self.dtype = dtype
|
| 59 |
+
self.hybrid_override_pattern = config.hybrid_block_layout
|
| 60 |
+
self.has_previous_state = False # only used by mamba
|
| 61 |
+
intermediate_size = (
|
| 62 |
+
config.ssm_cfg["d_inner"]
|
| 63 |
+
if config.ssm_cfg["d_inner"] is not None
|
| 64 |
+
else config.ssm_cfg["expand"] * config.hidden_size
|
| 65 |
+
)
|
| 66 |
+
ssm_state_size = config.ssm_cfg["d_state"]
|
| 67 |
+
conv_kernel_size = config.ssm_cfg["d_conv"]
|
| 68 |
+
self.n_qk_heads = config.ssm_cfg["n_qk_heads"]
|
| 69 |
+
self.num_C_head = intermediate_size // ssm_state_size # mamba2
|
| 70 |
+
assert intermediate_size % self.n_qk_heads == 0, "d_inner must be divisible by n_qk_heads"
|
| 71 |
+
self.head_d = intermediate_size // self.n_qk_heads
|
| 72 |
+
self.conv_states = []
|
| 73 |
+
self.ssm_states = []
|
| 74 |
+
self.transformer_layers = []
|
| 75 |
+
for i in range(config.num_hidden_layers):
|
| 76 |
+
if self.hybrid_override_pattern[i] == "m2d":
|
| 77 |
+
# Mamba layer
|
| 78 |
+
self.conv_states += [
|
| 79 |
+
torch.zeros(
|
| 80 |
+
batch_size,
|
| 81 |
+
conv_kernel_size,
|
| 82 |
+
intermediate_size + 2 * self.n_qk_heads * ssm_state_size,
|
| 83 |
+
device=device,
|
| 84 |
+
dtype=dtype,
|
| 85 |
+
).transpose(1, 2)
|
| 86 |
+
]
|
| 87 |
+
self.ssm_states += [
|
| 88 |
+
torch.zeros(batch_size, self.n_qk_heads, self.head_d, ssm_state_size, device=device, dtype=dtype)
|
| 89 |
+
]
|
| 90 |
+
elif self.hybrid_override_pattern[i] == "m2":
|
| 91 |
+
if "repeat_kv_before_conv" in config.ssm_cfg:
|
| 92 |
+
assert (
|
| 93 |
+
config.ssm_cfg["repeat_kv_before_conv"] == True
|
| 94 |
+
), "Only support repeat_kv_before_conv=True for m2 for now"
|
| 95 |
+
|
| 96 |
+
self.conv_states += [
|
| 97 |
+
torch.zeros(
|
| 98 |
+
batch_size,
|
| 99 |
+
intermediate_size,
|
| 100 |
+
conv_kernel_size,
|
| 101 |
+
device=device,
|
| 102 |
+
dtype=dtype,
|
| 103 |
+
)
|
| 104 |
+
]
|
| 105 |
+
self.ssm_states += [
|
| 106 |
+
torch.zeros(
|
| 107 |
+
batch_size,
|
| 108 |
+
self.num_C_head,
|
| 109 |
+
intermediate_size // self.num_C_head,
|
| 110 |
+
ssm_state_size,
|
| 111 |
+
device=device,
|
| 112 |
+
dtype=dtype,
|
| 113 |
+
)
|
| 114 |
+
]
|
| 115 |
+
else:
|
| 116 |
+
# Attention or MLP layer
|
| 117 |
+
self.conv_states += [torch.tensor([[]] * batch_size, device=device)]
|
| 118 |
+
self.ssm_states += [torch.tensor([[]] * batch_size, device=device)]
|
| 119 |
+
self.transformer_layers.append(i)
|
| 120 |
+
|
| 121 |
+
self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
|
| 122 |
+
self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
|
| 123 |
+
|
| 124 |
+
def update(
|
| 125 |
+
self,
|
| 126 |
+
key_states: torch.Tensor,
|
| 127 |
+
value_states: torch.Tensor,
|
| 128 |
+
layer_idx: int,
|
| 129 |
+
cache_kwargs: Optional[dict[str, Any]] = None,
|
| 130 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 131 |
+
# Update the cache
|
| 132 |
+
if self.key_cache[layer_idx].shape[-1] == 0:
|
| 133 |
+
self.key_cache[layer_idx] = key_states
|
| 134 |
+
self.value_cache[layer_idx] = value_states
|
| 135 |
+
else:
|
| 136 |
+
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2)
|
| 137 |
+
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2)
|
| 138 |
+
|
| 139 |
+
return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
| 140 |
+
|
| 141 |
+
def reorder_cache(self, beam_idx: torch.LongTensor):
|
| 142 |
+
"""Reorders the cache for beam search, given the selected beam indices."""
|
| 143 |
+
for layer_idx in range(len(self.key_cache)):
|
| 144 |
+
device = self.key_cache[layer_idx].device
|
| 145 |
+
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
|
| 146 |
+
device = self.value_cache[layer_idx].device
|
| 147 |
+
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
|
| 148 |
+
|
| 149 |
+
device = self.conv_states[layer_idx].device
|
| 150 |
+
self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device))
|
| 151 |
+
device = self.ssm_states[layer_idx].device
|
| 152 |
+
self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device))
|
| 153 |
+
|
| 154 |
+
def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]:
|
| 155 |
+
raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")
|
| 156 |
+
|
| 157 |
+
@classmethod
|
| 158 |
+
def from_legacy_cache(cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
|
| 159 |
+
raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")
|
| 160 |
+
|
| 161 |
+
# Copied from modeling_mamba2.py
|
| 162 |
+
def update_conv_state(
|
| 163 |
+
self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False
|
| 164 |
+
) -> torch.Tensor:
|
| 165 |
+
if cache_init:
|
| 166 |
+
self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device)
|
| 167 |
+
else:
|
| 168 |
+
self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1)
|
| 169 |
+
self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device)
|
| 170 |
+
return self.conv_states[layer_idx]
|
| 171 |
+
|
| 172 |
+
def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
|
| 173 |
+
self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device)
|
| 174 |
+
return self.ssm_states[layer_idx]
|
| 175 |
+
|
| 176 |
+
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
| 177 |
+
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
|
| 178 |
+
# take any layer that contains cache and not empty tensor
|
| 179 |
+
layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx
|
| 180 |
+
if len(self.key_cache) <= layer_idx:
|
| 181 |
+
return 0
|
| 182 |
+
is_empty_layer = (
|
| 183 |
+
len(self.key_cache) == 0 # no cache in any layer
|
| 184 |
+
or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it
|
| 185 |
+
or not self.key_cache[layer_idx].numel() # the layer has no cache
|
| 186 |
+
)
|
| 187 |
+
return self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
|
| 188 |
+
# return self.key_cache[layer_idx].shape[-2]
|
| 189 |
+
|
| 190 |
+
def reset(self):
|
| 191 |
+
self.conv_states.zero_()
|
| 192 |
+
self.ssm_states.zero_()
|
| 193 |
+
|
| 194 |
+
def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]:
|
| 195 |
+
raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")
|
| 196 |
+
|
| 197 |
+
@classmethod
|
| 198 |
+
def from_legacy_cache(cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
|
| 199 |
+
raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
@dataclass
|
| 203 |
+
class AprielHybridCausalOutput(ModelOutput):
|
| 204 |
+
"""Custom output class for MambaLMHeadModel."""
|
| 205 |
+
|
| 206 |
+
loss: Optional[torch.FloatTensor] = None
|
| 207 |
+
logits: Optional[torch.FloatTensor] = None
|
| 208 |
+
all_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
|
| 209 |
+
last_hidden_state: Optional[torch.FloatTensor] = None
|
| 210 |
+
attention_weights: Optional[torch.FloatTensor] = None
|
| 211 |
+
past_key_values: Optional[Cache] = None
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def segsum(x):
|
| 215 |
+
"""More stable segment sum calculation."""
|
| 216 |
+
# [1, 2, 3]
|
| 217 |
+
T = x.size(-1)
|
| 218 |
+
x = repeat(x, "... d -> ... d e", e=T)
|
| 219 |
+
# [[1, 1, 1], [2, 2, 2], [3, 3, 3]]
|
| 220 |
+
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1)
|
| 221 |
+
x = x.masked_fill(~mask, 0)
|
| 222 |
+
# [[0, 0, 0], [2, 0, 0], [3, 3, 0]]
|
| 223 |
+
x_segsum = torch.cumsum(x, dim=-2)
|
| 224 |
+
# [[0, 0, 0], [2, 0, 0], [5, 3, 0]]
|
| 225 |
+
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
|
| 226 |
+
x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
|
| 227 |
+
return x_segsum
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def materialize_mixer(A_log, B, C, D):
|
| 231 |
+
"""
|
| 232 |
+
Since the transfer matrix will be equated to the attention matrix,
|
| 233 |
+
we need to support the form: torch.matmul(attn_weights, value_states).
|
| 234 |
+
Thus, y = torch.matmul(T, X)
|
| 235 |
+
Arguments:
|
| 236 |
+
A_log: (batch, length, n_heads)
|
| 237 |
+
B: (batch, length, n_heads, d_state)
|
| 238 |
+
C: (batch, length, n_heads, d_state)
|
| 239 |
+
Return:
|
| 240 |
+
T: (batch, n_heads, length, length)
|
| 241 |
+
"""
|
| 242 |
+
batch_size, length, n_heads, d_state = B.shape
|
| 243 |
+
assert A_log.shape == (batch_size, length, n_heads)
|
| 244 |
+
assert B.shape == C.shape == (batch_size, length, n_heads, d_state)
|
| 245 |
+
|
| 246 |
+
# Compute:
|
| 247 |
+
A_log = rearrange(-F.softplus(A_log), "b l h -> b h l")
|
| 248 |
+
powers = torch.exp(segsum(A_log))
|
| 249 |
+
T = torch.einsum("blhn,bshn,bhls->bhsl", C, B, powers)
|
| 250 |
+
|
| 251 |
+
# Add D:
|
| 252 |
+
if D is not None:
|
| 253 |
+
T[:, :, torch.arange(length), torch.arange(length)] += D.view(1, n_heads, 1)
|
| 254 |
+
|
| 255 |
+
T = rearrange(T, "b h z l -> b h l z")
|
| 256 |
+
return T
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def apply_mask_to_padding_states(hidden_states, attention_mask):
|
| 260 |
+
"""
|
| 261 |
+
Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
|
| 262 |
+
"""
|
| 263 |
+
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
|
| 264 |
+
dtype = hidden_states.dtype
|
| 265 |
+
hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
|
| 266 |
+
|
| 267 |
+
return hidden_states
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
class Mamba(nn.Module):
|
| 271 |
+
def __init__(
|
| 272 |
+
self,
|
| 273 |
+
d_model,
|
| 274 |
+
d_inner,
|
| 275 |
+
d_xb=None,
|
| 276 |
+
d_state=16,
|
| 277 |
+
d_conv=4,
|
| 278 |
+
expand=2,
|
| 279 |
+
dt_rank="auto",
|
| 280 |
+
dt_min=0.001,
|
| 281 |
+
dt_max=0.1,
|
| 282 |
+
dt_init="random",
|
| 283 |
+
dt_scale=1.0,
|
| 284 |
+
dt_init_floor=1e-4,
|
| 285 |
+
repeat_kv_before_conv=True,
|
| 286 |
+
conv_bias=True,
|
| 287 |
+
bias=False,
|
| 288 |
+
dt_proj_bias=True,
|
| 289 |
+
use_fast_path=True, # Fused kernel options
|
| 290 |
+
layer_idx=None,
|
| 291 |
+
device=None,
|
| 292 |
+
dtype=None,
|
| 293 |
+
**kwargs,
|
| 294 |
+
):
|
| 295 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 296 |
+
super().__init__()
|
| 297 |
+
self.d_model = d_model
|
| 298 |
+
self.d_xb = d_xb if d_xb is not None else d_model
|
| 299 |
+
self.d_state = d_state
|
| 300 |
+
self.d_conv = d_conv
|
| 301 |
+
self.expand = expand
|
| 302 |
+
self.d_inner = d_inner if d_inner is not None else int(self.expand * self.d_model)
|
| 303 |
+
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
|
| 304 |
+
self.use_fast_path = use_fast_path
|
| 305 |
+
self.layer_idx = layer_idx
|
| 306 |
+
self.repeat_kv_before_conv = repeat_kv_before_conv
|
| 307 |
+
|
| 308 |
+
if self.repeat_kv_before_conv:
|
| 309 |
+
self.conv1d = nn.Conv1d(
|
| 310 |
+
in_channels=self.d_inner,
|
| 311 |
+
out_channels=self.d_inner,
|
| 312 |
+
bias=conv_bias,
|
| 313 |
+
kernel_size=d_conv,
|
| 314 |
+
groups=self.d_inner,
|
| 315 |
+
padding=d_conv - 1,
|
| 316 |
+
**factory_kwargs,
|
| 317 |
+
)
|
| 318 |
+
else:
|
| 319 |
+
self.conv1d = nn.Conv1d(
|
| 320 |
+
in_channels=self.d_xb,
|
| 321 |
+
out_channels=self.d_xb,
|
| 322 |
+
bias=conv_bias,
|
| 323 |
+
kernel_size=d_conv,
|
| 324 |
+
groups=self.d_xb,
|
| 325 |
+
padding=d_conv - 1,
|
| 326 |
+
**factory_kwargs,
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
self.activation = "silu"
|
| 330 |
+
self.act = nn.SiLU()
|
| 331 |
+
|
| 332 |
+
self.num_xb_head = self.d_xb // self.d_state
|
| 333 |
+
self.num_C_head = self.d_inner // self.d_state
|
| 334 |
+
self.repeat_group = self.num_C_head // self.num_xb_head
|
| 335 |
+
|
| 336 |
+
self.in_proj = nn.Linear(self.d_model, 2 * self.d_xb + 2 * self.d_inner, bias=bias, **factory_kwargs)
|
| 337 |
+
self.dt_in_proj = nn.Linear(self.d_model, self.dt_rank, bias=bias, **factory_kwargs)
|
| 338 |
+
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=dt_proj_bias, **factory_kwargs)
|
| 339 |
+
|
| 340 |
+
# Initialize special dt projection to preserve variance at initialization
|
| 341 |
+
dt_init_std = self.dt_rank**-0.5 * dt_scale
|
| 342 |
+
if dt_init == "constant":
|
| 343 |
+
nn.init.constant_(self.dt_proj.weight, dt_init_std)
|
| 344 |
+
elif dt_init == "random":
|
| 345 |
+
nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
|
| 346 |
+
else:
|
| 347 |
+
raise NotImplementedError
|
| 348 |
+
|
| 349 |
+
# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
|
| 350 |
+
dt = torch.exp(
|
| 351 |
+
torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
|
| 352 |
+
).clamp(min=dt_init_floor)
|
| 353 |
+
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
|
| 354 |
+
inv_dt = dt + torch.log(-torch.expm1(-dt))
|
| 355 |
+
with torch.no_grad():
|
| 356 |
+
self.dt_proj.bias.copy_(inv_dt)
|
| 357 |
+
# Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
|
| 358 |
+
self.dt_proj.bias._no_reinit = True
|
| 359 |
+
|
| 360 |
+
# S4D real initialization
|
| 361 |
+
A = repeat(
|
| 362 |
+
torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
|
| 363 |
+
"n -> d n",
|
| 364 |
+
d=self.d_inner,
|
| 365 |
+
).contiguous()
|
| 366 |
+
A_log = torch.log(A) # Keep A_log in fp32
|
| 367 |
+
self.A_log = nn.Parameter(A_log)
|
| 368 |
+
self.A_log._no_weight_decay = True
|
| 369 |
+
|
| 370 |
+
# D "skip" parameter
|
| 371 |
+
self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
|
| 372 |
+
self.D._no_weight_decay = True
|
| 373 |
+
|
| 374 |
+
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
|
| 375 |
+
|
| 376 |
+
def forward(
|
| 377 |
+
self,
|
| 378 |
+
hidden_states: torch.Tensor,
|
| 379 |
+
past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
|
| 380 |
+
mamba_mask: Optional[torch.Tensor] = None,
|
| 381 |
+
return_mixer_matrix=False,
|
| 382 |
+
**kwargs,
|
| 383 |
+
):
|
| 384 |
+
"""
|
| 385 |
+
hidden_states: (B, L, D)
|
| 386 |
+
Returns: same shape as hidden_states
|
| 387 |
+
"""
|
| 388 |
+
assert is_fast_path_available and "cuda" in self.in_proj.weight.device.type, "Only support fast path on cuda"
|
| 389 |
+
cache_position = kwargs.get("cache_position", None)
|
| 390 |
+
batch, seqlen, dim = hidden_states.shape
|
| 391 |
+
|
| 392 |
+
ssm_state, conv_state = None, None
|
| 393 |
+
use_precomputed_states = False
|
| 394 |
+
|
| 395 |
+
#########################################################
|
| 396 |
+
# Quick and dirty to work with CG
|
| 397 |
+
if "inference_params" in kwargs:
|
| 398 |
+
seqlen_offset = kwargs["inference_params"].seqlen_offset
|
| 399 |
+
if seqlen_offset > 0:
|
| 400 |
+
use_precomputed_states = True
|
| 401 |
+
else:
|
| 402 |
+
seqlen_offset = kwargs.get("seqlen_offset", cache_position[0]) if cache_position is not None else 0
|
| 403 |
+
use_precomputed_states = (
|
| 404 |
+
past_key_value is not None
|
| 405 |
+
and past_key_value.has_previous_state
|
| 406 |
+
and seqlen == 1
|
| 407 |
+
and past_key_value.conv_states[self.layer_idx].shape[0]
|
| 408 |
+
== past_key_value.ssm_states[self.layer_idx].shape[0]
|
| 409 |
+
== batch
|
| 410 |
+
and cache_position is not None
|
| 411 |
+
and seqlen_offset > 0
|
| 412 |
+
)
|
| 413 |
+
#########################################################
|
| 414 |
+
|
| 415 |
+
ssm_state, conv_state = self._get_states_from_cache(past_key_value, batch)
|
| 416 |
+
if use_precomputed_states:
|
| 417 |
+
out, _, _ = self.step(hidden_states, conv_state, ssm_state)
|
| 418 |
+
return {"hidden_states": out}
|
| 419 |
+
|
| 420 |
+
outputs = {}
|
| 421 |
+
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
|
| 422 |
+
|
| 423 |
+
zxbc = self.in_proj(hidden_states)
|
| 424 |
+
z, x, B, C = torch.split(
|
| 425 |
+
zxbc,
|
| 426 |
+
[
|
| 427 |
+
self.d_inner,
|
| 428 |
+
self.d_xb,
|
| 429 |
+
self.d_xb,
|
| 430 |
+
self.d_inner,
|
| 431 |
+
],
|
| 432 |
+
dim=-1,
|
| 433 |
+
)

        x = rearrange(x, "b l d -> b d l")
        z = rearrange(z, "b l d -> b d l")

        B = rearrange(B, "b l (n_group dstate) -> b n_group l dstate", dstate=self.d_state)
        B = repeat_kv(B, self.repeat_group)  # B, n_group, L, H
        B = rearrange(B, "b n_group l dstate -> b n_group dstate l").contiguous()
        C = rearrange(C, "b l (n_group dstate) -> b n_group dstate l", dstate=self.d_state).contiguous()

        dt = self.dt_proj(self.dt_in_proj(hidden_states))  # B, L, d_inner
        dt = rearrange(dt, "b l d -> b d l")  # B, d_inner, L

        if self.repeat_kv_before_conv:
            x = rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state)
            x = repeat_kv(x, self.repeat_group)
            x = rearrange(x, "b n_group l dstate -> b (n_group dstate) l")

        # Compute short convolution
        if conv_state is not None:
            # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
            # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
            # Update state (B D W)
            conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0)))
        if causal_conv1d_fn is None:
            x = self.act(self.conv1d(x)[..., :seqlen]).transpose(1, 2)
        else:
            assert self.activation in ["silu", "swish"]
            x = causal_conv1d_fn(
                x=x,
                weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
                bias=self.conv1d.bias,
                activation=self.activation,
            )
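
        # If x was kept at d_xb channels through the convolution, it is expanded to the full
        # d_inner width here before the selective scan, which discretizes (A, B) with the
        # per-channel step dt and returns the SSM output y (plus the final state when caching).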
        if not self.repeat_kv_before_conv:
            x = rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state)
            x = repeat_kv(x, self.repeat_group)
            x = rearrange(x, "b n_group l dstate -> b (n_group dstate) l")

        y = selective_scan_fn(
            x,
            dt,
            A,
            B,
            C,
            self.D.float(),
            z=z,
            delta_bias=self.dt_proj.bias.float(),
            delta_softplus=True,
            return_last_state=(ssm_state is not None),
        )

        if ssm_state is not None:
            y, last_state = y
            ssm_state.copy_(rearrange(last_state, "b (h d) n -> b h d n", h=self.num_C_head))

        y = rearrange(y, "b d l -> b l d")
        out = self.out_proj(y)

        outputs["hidden_states"] = out[:, :seqlen, :]
        return outputs
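
    # step() performs single-token decoding: it updates the cached convolution window and the
    # recurrent SSM state in place and returns the output for the one new token, so the
    # full-sequence scan above does not need to be re-run at every generation step.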

    def step(self, hidden_states, conv_state, ssm_state):
        dtype = hidden_states.dtype
        assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"

        hidden_states_input = hidden_states.squeeze(1)

        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)

        zxbc = self.in_proj(hidden_states_input)
        z, x, B, C = torch.split(zxbc, [self.d_inner, self.d_xb, self.d_xb, self.d_inner], dim=-1)

        B = rearrange(B, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state)
        B = torch.repeat_interleave(B, dim=1, repeats=self.repeat_group)
        C = rearrange(C, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state).contiguous()

        dt = self.dt_proj(self.dt_in_proj(hidden_states_input))  # B, d_inner

        if self.repeat_kv_before_conv:
            x = rearrange(x, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state)
            x = torch.repeat_interleave(x, dim=1, repeats=self.repeat_group)
            x = rearrange(x, "b n_group dstate -> b (n_group dstate)")

        # Conv step
        if causal_conv1d_update is None:
            # Update state (B D W)
            conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))
            conv_state[:, :, -1] = x
            x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)  # (B D)
            if self.conv1d.bias is not None:
                x = x + self.conv1d.bias
            x = self.act(x).to(dtype=dtype)
        else:
            x = causal_conv1d_update(
                x,
                conv_state,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                self.activation,
            )

        if not self.repeat_kv_before_conv:
            x = rearrange(x, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state)
            x = torch.repeat_interleave(x, dim=1, repeats=self.repeat_group)
            x = rearrange(x, "b n_group dstate -> b (n_group dstate)")

        x = rearrange(x, "b (h d) -> b h d", h=self.num_C_head)
        dt = rearrange(dt, "b (h d) -> b h d", h=self.num_C_head)
        A = rearrange(A, "(h d) n -> h d n", h=self.num_C_head)
        D = rearrange(self.D, "(h d) -> h d", h=self.num_C_head)
        z = rearrange(z, "b (h d) -> b h d", h=self.num_C_head)
        dt_bias = rearrange(self.dt_proj.bias, "(h d) -> h d", h=self.num_C_head)

        # SSM step
        assert selective_state_update is not None
        y = selective_state_update(ssm_state, x, dt, A, B, C, D, z=z, dt_bias=dt_bias, dt_softplus=True)
        y = rearrange(y, "b h d -> b (h d)")
        out = self.out_proj(y)

        return out.unsqueeze(1), conv_state, ssm_state
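
    # The per-layer cache holds a rolling convolution window over the last d_conv inputs and the
    # recurrent SSM state of shape (batch, num_C_head, d_inner // num_C_head, d_state).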

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        device = self.out_proj.weight.device
        conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
        if self.repeat_kv_before_conv:
            conv_state = torch.zeros(batch_size, self.d_inner, self.d_conv, device=device, dtype=conv_dtype)
        else:
            conv_state = torch.zeros(batch_size, self.d_xb, self.d_conv, device=device, dtype=conv_dtype)
        ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
        ssm_state = torch.zeros(
            batch_size, self.num_C_head, self.d_inner // self.num_C_head, self.d_state, device=device, dtype=ssm_dtype
        )
        return conv_state, ssm_state

    def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
        """
        conv_state: (batch, conv1d.weight.shape[0], d_conv)
        ssm_state: (batch, n_qk_heads, headdim, d_state)
        """
        assert self.layer_idx is not None

        # Get states
        ssm_states = inference_params.ssm_states[self.layer_idx]
        conv_states = inference_params.conv_states[self.layer_idx]
        if initialize_states:
            ssm_states.zero_()
            conv_states.zero_()
        return ssm_states, conv_states
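

# AprielSSMM2DecoderLayer mirrors the Mistral decoder layer but replaces self-attention with the
# Mamba mixer defined above; the RMSNorms, MLP and both residual connections are unchanged.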
class AprielSSMM2DecoderLayer(nn.Module):
    _mixer_class = Mamba

    def __init__(self, config: AprielHConfig, layer_idx: int, device=None, dtype=None, **kwargs):
        super().__init__(**kwargs)
        factory_kwargs = {"device": device, "dtype": dtype}
        self.hidden_size = config.hidden_size

        self.mixer = self._mixer_class(
            d_model=config.hidden_size,
            layer_idx=layer_idx,
            **config.ssm_cfg,
            **factory_kwargs,
        )

        self.mlp = MistralMLP(config)
        self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self, hidden_states: torch.Tensor, **kwargs
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:

        outputs = {}
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        mixer_outputs = self.mixer(
            hidden_states,
            **kwargs,
        )

        hidden_states = mixer_outputs["hidden_states"].to(residual.dtype) + residual

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        return outputs
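

# AprielHybridIdentity is a no-op block that passes hidden states through unchanged, so an "i"
# entry in hybrid_block_layout effectively skips that layer position.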
class AprielHybridIdentity(nn.Module):
    def __init__(self, config: AprielHConfig):
        super().__init__()
        self.config = config

    def forward(self, hidden_states: torch.Tensor, **kwargs):
        return (hidden_states,)


class AprielHModel(MistralModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is one of
    [`MistralDecoderLayer`], [`AprielSSMM2DecoderLayer`] or [`AprielHybridIdentity`], selected by
    `config.hybrid_block_layout`.

    Args:
        config: AprielHConfig
    """

    def __init__(self, config: AprielHConfig, **kwargs):
        config_copy = copy.deepcopy(config)
        config_copy.num_hidden_layers = 0
        super().__init__(config_copy, **kwargs)
        self.config = config
        blocks = []
        logger.info(f"Loading hybrid model with the following layout: {config.hybrid_block_layout}")
        for layer_idx, type in enumerate(config.hybrid_block_layout):
            if type == "m2":
                blocks.append(AprielSSMM2DecoderLayer(config, layer_idx))
            elif type == "t":
                blocks.append(MistralDecoderLayer(config, layer_idx))
            elif type == "i":
                blocks.append(AprielHybridIdentity(config))
            else:
                raise ValueError(f"Invalid block type: {type}")
        self.layers = nn.ModuleList(blocks)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
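        # A single HybridMambaAttentionDynamicCache holds key/value tensors for the attention ("t")
        # layers and conv/ssm states for the Mamba ("m2") layers, indexed by layer_idx.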
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        if use_cache and past_key_values is None:
            # for the case where prepare_inputs_for_generation is not called to create the cache (as in fast-llm test)
            batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
            past_key_values = HybridMambaAttentionDynamicCache(self.config, batch_size, self.dtype, device=self.device)
        output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **flash_attn_kwargs,
        )
        past_key_values: HybridMambaAttentionDynamicCache = output.past_key_values
        if past_key_values and not past_key_values.has_previous_state:
            past_key_values.has_previous_state = True
        return output


class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


class AprielThinkerSSMHybridPreTrainedModel(PreTrainedModel):
    config_class = AprielHConfig
    base_model_prefix = "model"
    _no_split_modules = ["MistralDecoderLayer", "AprielSSMDecoderLayer", "AprielSSMM2DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, MistralRMSNorm):
            module.weight.data.fill_(1.0)


class AprielHForCausalLM(AprielThinkerSSMHybridPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}

    def __init__(self, config: AprielHConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.model = AprielHModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        output_router_logits=False,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        **kwargs,
    ):
        # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache`

        empty_past_kv = past_key_values is None or not isinstance(past_key_values, HybridMambaAttentionDynamicCache)

        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
        # (we can't check exception 3 while compiling)
        if not empty_past_kv:
            if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]:  # Exception 1  # Exception 3
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]
        else:
            past_key_values = HybridMambaAttentionDynamicCache(
                self.config, input_ids.shape[0], self.dtype, device=self.device
            )

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if not empty_past_kv:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and empty_past_kv:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
                "output_router_logits": output_router_logits,
                "cache_position": cache_position,
            }
        )
        return model_inputs

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> Union[tuple, CausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only
            for that token can save memory, which becomes pretty significant for long sequences or large vocabulary
            size. If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length
            dimension. This is useful when using packed tensor format (single dimension for batch and sequence length).

        Returns:

        Example:

        ```python
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer

        >>> # The checkpoint path below is a placeholder; point it at this repository's checkpoint.
        >>> model = AutoModelForCausalLM.from_pretrained("path/to/apriel-h-checkpoint", trust_remote_code=True)
        >>> tokenizer = AutoTokenizer.from_pretrained("path/to/apriel-h-checkpoint")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            mamba_mask=attention_mask,  # non-expanded mask
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return AprielHybridCausalOutput(
            loss=loss,
            logits=logits,
            all_hidden_states=outputs.hidden_states,
            past_key_values=outputs.past_key_values,
        )


__all__ = [
    "AprielHForCausalLM",
    "AprielHModel",
]
tokenizer.json
ADDED
The diff for this file is too large to render. See raw diff.

tokenizer_config.json
ADDED
The diff for this file is too large to render. See raw diff.