ijktech-jk commited on
Commit
e1b0d36
Β·
verified Β·
1 Parent(s): c91c97c

Add README with project details

Browse files
Files changed (1) hide show
  1. README.md +136 -0
README.md CHANGED
@@ -128,6 +128,142 @@ print(f"Generated text: {generated_text}")
128
  #A dinner is only available for St. Loui
129
  ```
130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  ## πŸ“œ License
132
  πŸ“ **CC-BY-NC-4.0**: Free for non-commercial use.
133
 
 
128
  #A dinner is only available for St. Loui
129
  ```
130
 
131
+ ### Android Usage
132
+
133
+ The model can be used on Android devices using ONNX Runtime Mobile. Here's an example using Kotlin:
134
+
135
+ ```kotlin
136
+ import ai.onnxruntime.*
137
+ import java.nio.LongBuffer
138
+
139
+ class ByteGPTTokenizer {
140
+ companion object {
141
+ private const val PAD_TOKEN = "<pad>"
142
+ private const val EOS_TOKEN = "</s>"
143
+ private const val UNK_TOKEN = "<unk>"
144
+
145
+ // Token IDs for special tokens
146
+ private const val PAD_ID = 0L
147
+ private const val EOS_ID = 1L
148
+ private const val UNK_ID = 2L
149
+ private const val OFFSET = 3L // Number of special tokens
150
+ }
151
+
152
+ fun encode(text: String): LongArray {
153
+ // Convert text to UTF-8 bytes and add offset
154
+ val bytes = text.encodeToByteArray()
155
+ val ids = bytes.map { (it.toInt() and 0xFF).toLong() + OFFSET }.toLongArray()
156
+
157
+ // Add EOS token
158
+ return ids + EOS_ID
159
+ }
160
+
161
+ fun decode(ids: LongArray): String {
162
+ // Convert IDs back to bytes, handling special tokens
163
+ val bytes = ids.mapNotNull { id ->
164
+ when (id) {
165
+ PAD_ID -> null
166
+ EOS_ID -> null
167
+ UNK_ID -> null
168
+ else -> (id - OFFSET).toByte()
169
+ }
170
+ }.toByteArray()
171
+
172
+ return bytes.toString(Charsets.UTF_8)
173
+ }
174
+ }
175
+
176
+ class ByteGPTGenerator(
177
+ private val context: Context,
178
+ private val modelPath: String = "model_mobile.ort",
179
+ private val maxLength: Int = 512
180
+ ) {
181
+ private val env = OrtEnvironment.getEnvironment()
182
+ private val session: OrtSession
183
+ private val tokenizer = ByteGPTTokenizer()
184
+
185
+ init {
186
+ context.assets.open(modelPath).use { modelInput ->
187
+ val modelBytes = modelInput.readBytes()
188
+ session = env.createSession(modelBytes)
189
+ }
190
+ }
191
+
192
+ fun generate(prompt: String, maxNewTokens: Int = 50, temperature: Float = 1.0f): String {
193
+ var currentIds = tokenizer.encode(prompt)
194
+
195
+ for (i in 0 until maxNewTokens) {
196
+ if (currentIds.size >= maxLength) break
197
+
198
+ // Prepare input tensor
199
+ val shape = longArrayOf(1, currentIds.size.toLong())
200
+ val tensorInput = OnnxTensor.createTensor(
201
+ env,
202
+ LongBuffer.wrap(currentIds),
203
+ shape
204
+ )
205
+
206
+ // Run inference
207
+ val output = session.run(
208
+ mapOf("input" to tensorInput),
209
+ setOf("output")
210
+ )
211
+
212
+ // Get logits for the last token
213
+ val logits = output[0].value as Array<Array<Array<Float>>>
214
+ val lastTokenLogits = logits[0].last()
215
+
216
+ // Apply temperature
217
+ if (temperature != 1.0f) {
218
+ for (j in lastTokenLogits.indices) {
219
+ lastTokenLogits[j] /= temperature
220
+ }
221
+ }
222
+
223
+ // Convert to probabilities using softmax
224
+ val expLogits = lastTokenLogits.map { Math.exp(it.toDouble()) }
225
+ val sum = expLogits.sum()
226
+ val probs = expLogits.map { it / sum }
227
+
228
+ // Sample from distribution
229
+ val random = Math.random()
230
+ var cumsum = 0.0
231
+ var nextToken = 0
232
+ for (j in probs.indices) {
233
+ cumsum += probs[j]
234
+ if (random < cumsum) {
235
+ nextToken = j
236
+ break
237
+ }
238
+ }
239
+
240
+ // Append new token
241
+ currentIds = currentIds.plus(nextToken.toLong())
242
+
243
+ // Stop if we generate EOS
244
+ if (nextToken == ByteGPTTokenizer.EOS_ID) break
245
+ }
246
+
247
+ return tokenizer.decode(currentIds)
248
+ }
249
+ }
250
+
251
+ // Usage example:
252
+ val generator = ByteGPTGenerator(context)
253
+ val result = generator.generate("Once upon a time")
254
+ println(result)
255
+ ```
256
+
257
+ Make sure to:
258
+ 1. Add the ONNX Runtime Mobile dependency to your `build.gradle`:
259
+ ```gradle
260
+ dependencies {
261
+ implementation 'com.microsoft.onnxruntime:onnxruntime-android:latest.release'
262
+ }
263
+ ```
264
+
265
+ 2. Place the `model_mobile.ort` file in your app's assets folder.
266
+
267
  ## πŸ“œ License
268
  πŸ“ **CC-BY-NC-4.0**: Free for non-commercial use.
269