"""HuggingFace-compatible causal language model wrapper around the ukraine Transformer."""

import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from typing import Optional, Tuple, Union

from ukraine.research.transformer.transformer import Transformer
from ukraine.research.transformer.layers import SiLUFeedForward
from ukraine.research.transformer.masking import generate_square_subsequent_mask

from .configuration_lime import LIMEConfig


def make_ff(config: LIMEConfig):
    """Build the SiLU feed-forward block used in every transformer layer."""
    return SiLUFeedForward(
        d_model=config.d_model,
        dff=config.dff,
        multiple_of=config.multiple_of
    )


def make_norm(config: LIMEConfig):
    """Build the RMSNorm layer used throughout the model."""
    return nn.RMSNorm(config.d_model)


class LIMEForCausalLM(PreTrainedModel, GenerationMixin):
    config_class = LIMEConfig
    base_model_prefix = "lime"
    _tied_weights_keys = ["transformer.output_fc.weight"]

    def __init__(self, config: LIMEConfig):
        super().__init__(config)
        self.config = config

        self.transformer = Transformer(
            num_encoder_layers=config.num_encoder_layers,
            num_decoder_layers=config.num_decoder_layers,
            d_model=config.d_model,
            num_heads=config.num_heads,
            input_vocab_size=config.vocab_size,
            target_vocab_size=config.vocab_size,
            dropout_rate=config.dropout_rate,
            ff_factory=lambda: make_ff(config),
            norm_factory=lambda: make_norm(config),
            pad_token_id=config.pad_token_id,
            use_encoder=config.use_encoder,
            use_flash=config.use_flash
        )

        self.post_init()

    # Embedding accessors required by the transformers library.
    def get_input_embeddings(self):
        return self.transformer.decoder.embedding

    def set_input_embeddings(self, value):
        self.transformer.decoder.embedding = value

    def get_output_embeddings(self):
        return self.transformer.output_fc

    def set_output_embeddings(self, new_embeddings):
        self.transformer.output_fc = new_embeddings

    def _tie_weights(self):
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(
                self.transformer.output_fc,
                self.get_input_embeddings()
            )

    def forward(
        self,
        input_ids: torch.LongTensor,
        # attention_mask and token_type_ids are accepted for API compatibility
        # with transformers pipelines but are not used by this model.
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Causal (look-ahead) mask: each position may only attend to earlier positions.
        tgt_mask = generate_square_subsequent_mask(seq_len, device)

        # During training, mask out the padding tokens.
        if labels is not None:
            tgt_key_padding_mask = input_ids.eq(self.config.pad_token_id)
        # For inference we do not need it.
        else:
            tgt_key_padding_mask = None

        logits, _ = self.transformer(
            src=input_ids,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask
        )

        loss = None
        if labels is not None:
            # Shift so that the token at position i predicts the token at position i + 1.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            # This ignore index was used during SFT training.
            criterion = nn.CrossEntropyLoss(ignore_index=-100)
            loss = criterion(
                shift_logits.reshape(-1, self.config.vocab_size),
                shift_labels.reshape(-1)
            )

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None
        )
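

# --- Minimal usage sketch (not part of the model definition) ---
# Assumptions: LIMEConfig (defined in configuration_lime.py) accepts the keyword
# arguments below, and the hyperparameter values are illustrative only.
# The loss path exercises the forward() defined above; generation relies on the
# default GenerationMixin loop, which this wrapper does not customize.
if __name__ == "__main__":
    config = LIMEConfig(
        vocab_size=32000,
        d_model=512,
        num_heads=8,
        num_encoder_layers=0,
        num_decoder_layers=8,
        dff=2048,
        multiple_of=256,
        dropout_rate=0.1,
        pad_token_id=0,
        use_encoder=False,
        use_flash=True,
        tie_word_embeddings=True,
    )
    model = LIMEForCausalLM(config)
    model.eval()

    # Training-style call: passing labels triggers the shifted cross-entropy loss.
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    outputs = model(input_ids=input_ids, labels=input_ids)
    print("loss:", outputs.loss.item(), "logits:", tuple(outputs.logits.shape))

    # Inference call: greedy decoding through GenerationMixin.generate().
    prompt = torch.randint(0, config.vocab_size, (1, 8))
    generated = model.generate(prompt, max_new_tokens=16, do_sample=False)
    print("generated ids:", tuple(generated.shape))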