Upload folder using huggingface_hub
- README.md +1 -1
- config.json +1 -6
- configuration_midashenglm.py +5 -11
- model.safetensors.index.json +13 -13
- modeling_midashenglm.py +29 -47
- processing_midashenglm.py +23 -20
README.md
CHANGED
@@ -51,7 +51,7 @@ base_model:
 
 >>> import torch
 >>> with torch.no_grad():
-...     model_inputs = processor(text=text, audio=audio)
+...     model_inputs = processor(text=text, audio=audio, sampling_rate=sr)
 ...     generation = model.generate(**model_inputs)
 ...     output = processor.batch_decode(generation, skip_special_tokens=True)
 
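For readers of the updated README snippet: `processor`, `model`, `text`, `audio`, and `sr` are defined earlier in the README. The sketch below fills those names in with assumed placeholders (the checkpoint path, audio file, and prompt are illustrative, not taken from this commit) to show where the new `sampling_rate` argument comes from.

```python
# Illustrative setup only: the checkpoint path, audio file, and prompt are placeholders.
import soundfile as sf
import torch
from transformers import AutoModelForCausalLM, AutoProcessor

repo = "./"  # assumed local checkpoint directory for this repo
processor = AutoProcessor.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)

audio, sr = sf.read("example.wav")  # waveform and its native sampling rate
text = "Describe the audio."        # placeholder prompt

with torch.no_grad():
    model_inputs = processor(text=text, audio=audio, sampling_rate=sr)
    generation = model.generate(**model_inputs)
    output = processor.batch_decode(generation, skip_special_tokens=True)
print(output)
```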
config.json
CHANGED
@@ -37,15 +37,10 @@
     "AutoConfig": "configuration_midashenglm.MiAudioLLMHFConfig",
     "AutoModelForCausalLM": "modeling_midashenglm.DashengQwen25OmniModelInstruct"
   },
-  "freeze": null,
-  "gradient_checkpoint_decoder": false,
-  "lora": null,
-  "model": "DashengQwen25OmniModelInstruct",
   "model_type": "miaudiollm",
   "resize_tokenizer": false,
   "subsample_factor": 5,
-  "text_model_config": {
-    "_attn_implementation_autoset": true,
+  "text_config": {
     "attention_dropout": 0.0,
     "hidden_act": "silu",
     "hidden_size": 2048,
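After this change the decoder settings live in a nested `text_config` block rather than under the old top-level keys; `MiAudioLLMHFConfig` parses that block into a `Qwen2_5OmniTextConfig`. A minimal sketch (the path stands in for an assumed local checkpoint directory):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("./", trust_remote_code=True)  # assumed path
print(type(config.text_config).__name__)  # Qwen2_5OmniTextConfig
print(config.text_config.hidden_size)     # 2048, matching the JSON above
```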
configuration_midashenglm.py
CHANGED
@@ -1,5 +1,5 @@
 from ast import Dict
-from typing import Literal, Tuple, Union
+from typing import Tuple, Union
 
 from transformers import PretrainedConfig
 from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import (
@@ -66,22 +66,16 @@ class MiAudioLLMHFConfig(PretrainedConfig):
 
     def __init__(
         self,
-        model: str = "DashengQwen2ModelInstruct",
         audio_encoder_config: Dict = {},
-        freeze: Literal["audio", "text"] | str | None = None,
-        lora: Literal["encoder", "decoder"] | None = None,
         subsample_factor: int = 5,
-        text_model_config: Dict = None,
+        text_config: Dict = None,
         **kwargs,
     ):
-        self.model = model
         self.audio_encoder_config = DashengConfig(**audio_encoder_config)
-        self.freeze = freeze
-        self.lora = lora
         self.subsample_factor = subsample_factor
-        self.text_model_config = (
-            Qwen2_5OmniTextConfig(**text_model_config)
-            if text_model_config
+        self.text_config = (
+            Qwen2_5OmniTextConfig(**text_config)
+            if text_config
             else Qwen2_5OmniTextConfig()
         )
         super().__init__(**kwargs)
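A minimal sketch of the new constructor surface: `text_config` is an optional dict that is wrapped in a `Qwen2_5OmniTextConfig`, and the decoder defaults apply when it is omitted (the override below is illustrative).

```python
from configuration_midashenglm import MiAudioLLMHFConfig

cfg = MiAudioLLMHFConfig(
    audio_encoder_config={},            # DashengConfig defaults
    subsample_factor=5,
    text_config={"hidden_size": 2048},  # forwarded to Qwen2_5OmniTextConfig
)

cfg_default = MiAudioLLMHFConfig()      # text_config omitted -> Qwen2_5OmniTextConfig()
```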
model.safetensors.index.json
CHANGED
@@ -390,20 +390,20 @@
     "audio_encoder.freq_pos_embed": "model-00001-of-00002.safetensors",
     "audio_encoder.front_end.0.mel_scale.fb": "model-00001-of-00002.safetensors",
     "audio_encoder.front_end.0.spectrogram.window": "model-00001-of-00002.safetensors",
-    "audio_encoder.init_bn.1.bias": "model-00001-of-00002.safetensors",
-    "audio_encoder.init_bn.1.num_batches_tracked": "model-00001-of-00002.safetensors",
-    "audio_encoder.init_bn.1.running_mean": "model-00001-of-00002.safetensors",
-    "audio_encoder.init_bn.1.running_var": "model-00001-of-00002.safetensors",
-    "audio_encoder.init_bn.1.weight": "model-00001-of-00002.safetensors",
+    "audio_encoder.init_bn.bias": "model-00001-of-00002.safetensors",
+    "audio_encoder.init_bn.num_batches_tracked": "model-00001-of-00002.safetensors",
+    "audio_encoder.init_bn.running_mean": "model-00001-of-00002.safetensors",
+    "audio_encoder.init_bn.running_var": "model-00001-of-00002.safetensors",
+    "audio_encoder.init_bn.weight": "model-00001-of-00002.safetensors",
     "audio_encoder.norm.bias": "model-00001-of-00002.safetensors",
     "audio_encoder.norm.weight": "model-00001-of-00002.safetensors",
     "audio_encoder.patch_embed.proj.bias": "model-00001-of-00002.safetensors",
     "audio_encoder.patch_embed.proj.weight": "model-00001-of-00002.safetensors",
     "audio_encoder.time_pos_embed": "model-00001-of-00002.safetensors",
-    "audio_projector.net.0.bias": "model-00002-of-00002.safetensors",
-    "audio_projector.net.0.weight": "model-00002-of-00002.safetensors",
-    "audio_projector.net.2.bias": "model-00002-of-00002.safetensors",
-    "audio_projector.net.2.weight": "model-00002-of-00002.safetensors",
+    "audio_projector.net.0.bias": "model-00001-of-00002.safetensors",
+    "audio_projector.net.0.weight": "model-00001-of-00002.safetensors",
+    "audio_projector.net.2.bias": "model-00001-of-00002.safetensors",
+    "audio_projector.net.2.weight": "model-00001-of-00002.safetensors",
     "decoder.lm_head.weight": "model-00002-of-00002.safetensors",
     "decoder.model.embed_tokens.weight": "model-00001-of-00002.safetensors",
     "decoder.model.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
@@ -442,11 +442,11 @@
     "decoder.model.layers.10.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
     "decoder.model.layers.10.self_attn.v_proj.bias": "model-00001-of-00002.safetensors",
     "decoder.model.layers.10.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
-    "decoder.model.layers.11.input_layernorm.weight": "model-00001-of-00002.safetensors",
-    "decoder.model.layers.11.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+    "decoder.model.layers.11.input_layernorm.weight": "model-00002-of-00002.safetensors",
+    "decoder.model.layers.11.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
     "decoder.model.layers.11.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
-    "decoder.model.layers.11.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
-    "decoder.model.layers.11.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+    "decoder.model.layers.11.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
+    "decoder.model.layers.11.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
     "decoder.model.layers.11.self_attn.k_proj.bias": "model-00001-of-00002.safetensors",
     "decoder.model.layers.11.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
     "decoder.model.layers.11.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
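A quick sanity check one could run against the updated weight map (a sketch only; it assumes the index and both shard files sit in the current directory): every remapped `audio_encoder.init_bn.*` and `audio_projector.*` entry should resolve to a tensor stored in the shard it now points to.

```python
import json
from safetensors import safe_open

with open("model.safetensors.index.json") as f:
    weight_map = json.load(f)["weight_map"]

for name, shard in weight_map.items():
    if name.startswith(("audio_encoder.init_bn.", "audio_projector.")):
        with safe_open(shard, framework="pt") as shard_file:
            assert name in shard_file.keys(), f"{name} missing from {shard}"
```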
modeling_midashenglm.py
CHANGED
@@ -249,21 +249,12 @@ class Block(nn.Module):
         return x
 
 
-class RearranceReplace(nn.Module):
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        # rearrange(x, "b c f t -> b f c t")
-        # or
-        # rearrange(x, "b f c t -> b c f t")
-        return torch.permute(x, (0, 2, 1, 3))
+class AudioTransformer(PreTrainedModel):
+    config_class = DashengConfig
 
+    def __init__(self, config: DashengConfig):
+        super().__init__(config)
 
-class AudioTransformer(nn.Module):
-    def __init__(
-        self,
-        config: DashengConfig,
-    ):
-        super().__init__()
         self.target_length = config.target_length
         self.embed_dim = config.embed_dim
         self.hop_length = config.hop_length
@@ -282,13 +273,7 @@ class AudioTransformer(nn.Module):
             audio_transforms.AmplitudeToDB(top_db=120),
         )
 
-        self.init_bn = nn.Sequential(
-            # Rearrange("b c f t -> b f c t"),
-            RearranceReplace(),
-            nn.BatchNorm2d(config.n_mels, momentum=0.01),
-            # Rearrange("b f c t -> b c f t"),
-            RearranceReplace(),
-        )
+        self.init_bn = nn.BatchNorm2d(config.n_mels, momentum=0.01)
 
         self.patch_embed = AudioPatchEmbed(
             input_size=(config.n_mels, config.target_length),
@@ -327,6 +312,8 @@ class AudioTransformer(nn.Module):
         )
         self.norm = norm_layer(config.embed_dim)
 
+        self.post_init()
+
     def forward_features(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
         t = x.shape[-1]
         x = x + self.time_pos_embed[:, :, :, :t]
@@ -357,7 +344,9 @@ class AudioTransformer(nn.Module):
         x = self.front_end(x)
         target_length_in_patches = self.target_length // 4
         x = x.unsqueeze(1)
+        x = torch.permute(x, (0, 2, 1, 3))
         x = self.init_bn(x)
+        x = torch.permute(x, (0, 2, 1, 3))
 
         x = self.patch_embed(x)
         t = x.shape[-1]
@@ -427,23 +416,21 @@ class DashengQwen25OmniModelInstructOutput(ModelOutput):
 
 class Decoder(PreTrainedModel, GenerationMixin):
     config_class = Qwen2_5OmniTextConfig
+    _supports_flash_attn_2 = Qwen2_5OmniThinkerTextModel._supports_flash_attn_2
+    _supports_sdpa = Qwen2_5OmniThinkerTextModel._supports_sdpa
+    _supports_flex_attn = Qwen2_5OmniThinkerTextModel._supports_flex_attn
+    _supports_cache_class = Qwen2_5OmniThinkerTextModel._supports_cache_class
+    _supports_static_cache = Qwen2_5OmniThinkerTextModel._supports_static_cache
+    _supports_quantized_cache = Qwen2_5OmniThinkerTextModel._supports_quantized_cache
 
     def __init__(self, config: Qwen2_5OmniTextConfig):
         super().__init__(config)
-        self.model = Qwen2_5OmniThinkerTextModel._from_config(
-            config,
-            attn_implementation="sdpa",  # TODO
-        )
+        self.model = Qwen2_5OmniThinkerTextModel._from_config(config)
         self.lm_head = nn.Linear(
             config.hidden_size,
             config.vocab_size,
             bias=False,
         )
-        # TODO fix dtype
-        self.lm_head.weight.data = self.lm_head.weight.data.to(
-            self.model.embed_tokens.weight.dtype
-        )
-        # TODO tie weight?
         self.post_init()
 
     def forward(
@@ -481,30 +468,25 @@ class Decoder(PreTrainedModel, GenerationMixin):
 
 class DashengQwen25OmniModelInstruct(PreTrainedModel):
     config_class = MiAudioLLMHFConfig
+    _supports_flash_attn_2 = Qwen2_5OmniThinkerTextModel._supports_flash_attn_2
+    _supports_sdpa = Qwen2_5OmniThinkerTextModel._supports_sdpa
+    _supports_flex_attn = Qwen2_5OmniThinkerTextModel._supports_flex_attn
+    _supports_cache_class = Qwen2_5OmniThinkerTextModel._supports_cache_class
+    _supports_static_cache = Qwen2_5OmniThinkerTextModel._supports_static_cache
+    _supports_quantized_cache = Qwen2_5OmniThinkerTextModel._supports_quantized_cache
 
     def __init__(self, config: MiAudioLLMHFConfig):
         super().__init__(config)
 
-        freeze = config.freeze
-        lora = config.lora
-        subsample_factor = config.subsample_factor
-
-        self.subsample_factor = subsample_factor
-        self.lora = lora
-        # Encoder part
-        self.audio_encoder = AudioTransformer(config.audio_encoder_config)
-        assert lora != "encoder"
-
-        # decoder
-        self.decoder = Decoder(config.text_model_config)
-        assert lora != "decoder"
-        assert freeze is None
-
-        # audio projector
+        self.audio_encoder = AudioTransformer._from_config(config.audio_encoder_config)
         self.audio_projector = AudioProjectorSubsample(
             self.audio_encoder.embed_dim,
-            config.text_model_config.hidden_size,
-            subsample_factor,
+            config.text_config.hidden_size,
+            config.subsample_factor,
+        )
+        self.decoder = Decoder._from_config(
+            config.text_config,
+            attn_implementation=config._attn_implementation,
         )
 
         self.post_init()
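One practical consequence of these modeling changes: the decoder is now built with `attn_implementation=config._attn_implementation`, so the attention backend requested at load time is forwarded to the Qwen2.5-Omni text model instead of being hard-coded to `"sdpa"`. A brief sketch (the path and dtype are illustrative):

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "./",                                     # assumed local checkpoint directory
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # or "sdpa" / "eager"
)
```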
processing_midashenglm.py
CHANGED
@@ -55,32 +55,35 @@ class MiAudioLLMProcessor(ProcessorMixin):
         tokenizer: Qwen2Tokenizer | Qwen2TokenizerFast | None = None,
         model_subsampling: int = 5,
         chat_template: str | None = None,
-        audio_token: str = "<|AUDIO|>",
-        audio_bos_token: str = "<|audio_bos|>",
-        audio_eos_token: str = "<|audio_eos|>",
+        audio_token: str | None = None,
+        audio_bos_token: str | None = None,
+        audio_eos_token: str | None = None,
     ):
-        if chat_template is None:
-            chat_template = self.default_chat_template
         assert tokenizer is not None, "Tokenizer Needs to be passed"
-        self.audio_token = (
-            tokenizer.audio_token if hasattr(tokenizer, "audio_token") else audio_token
+        assert audio_token is not None or hasattr(tokenizer, "audio_token"), (
+            "Either `audio_token` must be provided or tokenizer must have `audio_token` attribute."
         )
-        self.audio_bos_token = (
-            tokenizer.audio_bos_token
-            if hasattr(tokenizer, "audio_bos_token")
-            else audio_bos_token
+        assert audio_bos_token is not None or hasattr(tokenizer, "audio_bos_token"), (
+            "Either `audio_bos_token` must be provided or tokenizer must have `audio_bos_token` attribute."
        )
-        self.audio_eos_token = (
-            tokenizer.audio_eos_token
-            if hasattr(tokenizer, "audio_eos_token")
-            else audio_eos_token
+        assert audio_eos_token is not None or hasattr(tokenizer, "audio_eos_token"), (
+            "Either `audio_eos_token` must be provided or tokenizer must have `audio_eos_token` attribute."
         )
+
+        if chat_template is None:
+            chat_template = self.default_chat_template
+
+        self.audio_token: str = audio_token or tokenizer.audio_token
+        self.audio_bos_token = audio_bos_token or tokenizer.audio_bos_token
+        self.audio_eos_token = audio_eos_token or tokenizer.audio_eos_token
+        self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token)
         self.model_subsampling = model_subsampling
-        if feature_extractor is not None:
-            feature_extractor.do_normalize = False
+
+        if feature_extractor is not None:
+            assert not feature_extractor.do_normalize, (
+                "This model does not use normalization. Please set `do_normalize=False` in the feature extractor."
+            )
+
         super().__init__(feature_extractor, tokenizer, chat_template=chat_template)
 
     def __call__(