Update modeling_llama.py
Browse files- modeling_llama.py +6 -1
modeling_llama.py
CHANGED
|
@@ -1117,7 +1117,12 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
|
|
| 1117 |
def detect_shutdown_token(self, input_ids):
|
| 1118 |
if torch.any(input_ids == self.shutdown_token_id):
|
| 1119 |
return True
|
| 1120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1121 |
def randomize_weights(self):
|
| 1122 |
with torch.no_grad():
|
| 1123 |
for param in self.parameters():
|
|
|
|
| 1117 |
def detect_shutdown_token(self, input_ids):
|
| 1118 |
if torch.any(input_ids == self.shutdown_token_id):
|
| 1119 |
return True
|
| 1120 |
+
def detect_shutdown_token(self, input_ids):
|
| 1121 |
+
shutdown_token_tensor = torch.tensor(self.shutdown_token_id, device=input_ids.device, dtype=input_ids.dtype)
|
| 1122 |
+
if torch.any(input_ids == shutdown_token_tensor):
|
| 1123 |
+
return True
|
| 1124 |
+
return False
|
| 1125 |
+
|
| 1126 |
def randomize_weights(self):
|
| 1127 |
with torch.no_grad():
|
| 1128 |
for param in self.parameters():
|