Update generate.py
Browse files: generate.py (+7 -4)
generate.py
CHANGED
|
@@ -11,7 +11,7 @@ def custom_generate(
|
|
| 11 |
self,
|
| 12 |
input_ids,
|
| 13 |
attention_mask=None,
|
| 14 |
-
|
| 15 |
min_length=None,
|
| 16 |
do_sample=None,
|
| 17 |
early_stopping=None,
|
|
@@ -47,7 +47,7 @@ def custom_generate(
|
|
| 47 |
with torch.no_grad():
|
| 48 |
finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=device)
|
| 49 |
|
| 50 |
-
|
| 51 |
# Sample the next token
|
| 52 |
new_ids = self(
|
| 53 |
input_ids[~finished_generating],
|
|
@@ -86,6 +86,9 @@ def custom_generate(
|
|
| 86 |
# Check if the end token is generated
|
| 87 |
if new_ids_sampled == self.tokenizer.convert_tokens_to_ids("</s>"):
|
| 88 |
finished_generating[answer_idx] = 1
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
if streamer is not None:
|
| 91 |
streamer.put(new_ids_sampled)
|
|
@@ -98,7 +101,7 @@ def generate(
|
|
| 98 |
self,
|
| 99 |
input_ids,
|
| 100 |
attention_mask=None,
|
| 101 |
-
|
| 102 |
min_length=None,
|
| 103 |
do_sample=None,
|
| 104 |
early_stopping=None,
|
|
@@ -169,7 +172,7 @@ def generate(
|
|
| 169 |
self,
|
| 170 |
input_ids=input_ids,
|
| 171 |
attention_mask=attention_mask,
|
| 172 |
-
|
| 173 |
min_length=min_length,
|
| 174 |
do_sample=do_sample,
|
| 175 |
early_stopping=early_stopping,
|
|
|
|
| 11 |
self,
|
| 12 |
input_ids,
|
| 13 |
attention_mask=None,
|
| 14 |
+
max_new_tokens=None,
|
| 15 |
min_length=None,
|
| 16 |
do_sample=None,
|
| 17 |
early_stopping=None,
|
|
|
|
| 47 |
with torch.no_grad():
|
| 48 |
finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=device)
|
| 49 |
|
| 50 |
+
for cur_token_idx in range(max_new_tokens):
|
| 51 |
# Sample the next token
|
| 52 |
new_ids = self(
|
| 53 |
input_ids[~finished_generating],
|
|
|
|
| 86 |
# Check if the end token is generated
|
| 87 |
if new_ids_sampled == self.tokenizer.convert_tokens_to_ids("</s>"):
|
| 88 |
finished_generating[answer_idx] = 1
|
| 89 |
+
|
| 90 |
+
if finished_generating.all():
|
| 91 |
+
break
|
| 92 |
|
| 93 |
if streamer is not None:
|
| 94 |
streamer.put(new_ids_sampled)
|
|
|
|
| 101 |
self,
|
| 102 |
input_ids,
|
| 103 |
attention_mask=None,
|
| 104 |
+
max_new_tokens=None,
|
| 105 |
min_length=None,
|
| 106 |
do_sample=None,
|
| 107 |
early_stopping=None,
|
|
|
|
| 172 |
self,
|
| 173 |
input_ids=input_ids,
|
| 174 |
attention_mask=attention_mask,
|
| 175 |
+
max_new_tokens=max_new_tokens,
|
| 176 |
min_length=min_length,
|
| 177 |
do_sample=do_sample,
|
| 178 |
early_stopping=early_stopping,
|