Update modeling_llama.py
Browse files- modeling_llama.py +2 -4
modeling_llama.py
CHANGED
|
@@ -1114,12 +1114,10 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
|
|
| 1114 |
def get_decoder(self):
|
| 1115 |
return self.model
|
| 1116 |
|
| 1117 |
-
|
| 1118 |
-
if torch.any(input_ids == self.shutdown_token_id):
|
| 1119 |
-
return True
|
| 1120 |
def detect_shutdown_token(self, input_ids):
|
| 1121 |
shutdown_token_tensor = torch.tensor(self.shutdown_token_id, device=input_ids.device, dtype=input_ids.dtype)
|
| 1122 |
-
|
| 1123 |
return True
|
| 1124 |
return False
|
| 1125 |
|
|
|
|
| 1114 |
def get_decoder(self):
|
| 1115 |
return self.model
|
| 1116 |
|
| 1117 |
+
|
|
|
|
|
|
|
| 1118 |
def detect_shutdown_token(self, input_ids):
|
| 1119 |
shutdown_token_tensor = torch.tensor(self.shutdown_token_id, device=input_ids.device, dtype=input_ids.dtype)
|
| 1120 |
+
if torch.any(input_ids == shutdown_token_tensor):
|
| 1121 |
return True
|
| 1122 |
return False
|
| 1123 |
|