Commit
·
90b56d4
1
Parent(s):
434063d
Upload bert_layers.py
Browse files - bert_layers.py +4 -9
bert_layers.py
CHANGED
|
@@ -610,7 +610,7 @@ class BertForMaskedLM(BertPreTrainedModel):
|
|
| 610 |
'If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for '
|
| 611 |
'bi-directional self-attention.')
|
| 612 |
|
| 613 |
-
self.bert = BertModel(config, add_pooling_layer=False)
|
| 614 |
self.cls = BertOnlyMLMHead(config,
|
| 615 |
self.bert.embeddings.word_embeddings.weight)
|
| 616 |
|
|
@@ -705,18 +705,13 @@ class BertForMaskedLM(BertPreTrainedModel):
|
|
| 705 |
return_dict=return_dict,
|
| 706 |
masked_tokens_mask=masked_tokens_mask,
|
| 707 |
)
|
| 708 |
-
|
| 709 |
if torch.isnan(outputs[0]).any():
|
| 710 |
print("NaNs in outputs.")
|
| 711 |
raise ValueError()
|
| 712 |
|
| 713 |
-
|
| 714 |
-
|
| 715 |
-
|
| 716 |
-
pooled_output = outputs[0]
|
| 717 |
-
|
| 718 |
-
last_hidden_state_formatted = outputs[0][:,0,:].view(-1, self.config.hidden_size)
|
| 719 |
-
return {"sentence_embedding": last_hidden_state_formatted}
|
| 720 |
|
| 721 |
def prepare_inputs_for_generation(self, input_ids: torch.Tensor,
|
| 722 |
attention_mask: torch.Tensor,
|
|
|
|
| 610 |
'If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for '
|
| 611 |
'bi-directional self-attention.')
|
| 612 |
|
| 613 |
+
self.bert = BertModel(config, add_pooling_layer=True)
|
| 614 |
self.cls = BertOnlyMLMHead(config,
|
| 615 |
self.bert.embeddings.word_embeddings.weight)
|
| 616 |
|
|
|
|
| 705 |
return_dict=return_dict,
|
| 706 |
masked_tokens_mask=masked_tokens_mask,
|
| 707 |
)
|
| 708 |
+
|
| 709 |
if torch.isnan(outputs[0]).any():
|
| 710 |
print("NaNs in outputs.")
|
| 711 |
raise ValueError()
|
| 712 |
|
| 713 |
+
pooled_output = outputs[1]
|
| 714 |
+
return {"sentence_embedding": pooled_output}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 715 |
|
| 716 |
def prepare_inputs_for_generation(self, input_ids: torch.Tensor,
|
| 717 |
attention_mask: torch.Tensor,
|