from __future__ import annotations

import torch
from torch import nn
from transformers import PretrainedConfig, PreTrainedModel


class FastTextJpConfig(PretrainedConfig):
    model_type = "fast_text_jp"

    def __init__(self, **kwargs):
        # Extra kwargs such as vocab_size and hidden_size are stored as
        # attributes by PretrainedConfig and read back by the model below.
        super().__init__(**kwargs)


class FastTextJpModel(PreTrainedModel):
    """Performs FastText embedding lookup."""

    config_class = FastTextJpConfig

    def __init__(self, config: FastTextJpConfig):
        super().__init__(config)
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)

    def forward(self, input_ids, **kwargs):
        # Look up the embedding vector for each input token id.
        return self.word_embeddings(input_ids)


# Register the custom classes so they can be resolved through AutoConfig / AutoModel.
FastTextJpConfig.register_for_auto_class()
FastTextJpModel.register_for_auto_class("AutoModel")
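
# Usage sketch (not part of the original listing): saving the model and loading
# it back through the Auto* API. Assumes this module is saved as a real file
# (e.g. modeling_fast_text_jp.py) so register_for_auto_class can copy it next to
# the weights; the directory name and the vocab_size / hidden_size values are
# hypothetical.
if __name__ == "__main__":
    from transformers import AutoModel

    config = FastTextJpConfig(vocab_size=32000, hidden_size=300)
    model = FastTextJpModel(config)
    # Writes config.json (including the auto_map entries added by the
    # register_for_auto_class calls above) plus the model weights.
    model.save_pretrained("fast_text_jp")
    # trust_remote_code=True lets transformers import the custom classes.
    reloaded = AutoModel.from_pretrained("fast_text_jp", trust_remote_code=True)
    print(type(reloaded))  # expected: FastTextJpModel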