Update custom_generate/generate.py

custom_generate/generate.py (+19 -34)
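All three hunks make the same change inside `_contrastive_search`: the legacy tuple-of-tuples `past_key_values` handling is removed, and any cache that is not a `DynamicCache` (or an `EncoderDecoderCache` wrapping one) now raises a `ValueError` that points users at `cache_implementation='dynamic'`.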
@@ -325,25 +325,18 @@ def _contrastive_search(
 
         if not sequential:
             # Replicates the new past_key_values to match the `top_k` candidates
-            past = model_kwargs["past_key_values"]
-            # Modify the cache in-place, layer after layer, to save memory
-            if isinstance(past, DynamicCache) or (
-                isinstance(past, EncoderDecoderCache)
-                and isinstance(past.self_attention_cache, DynamicCache)
-            ):
-                past.batch_repeat_interleave(top_k)
+            if isinstance(outputs["past_key_values"], DynamicCache) or (
+                isinstance(outputs["past_key_values"], EncoderDecoderCache)
+                and isinstance(
+                    outputs["past_key_values"].self_attention_cache, DynamicCache
+                )
+            ):
+                model_kwargs["past_key_values"] = model_kwargs["past_key_values"].batch_repeat_interleave(top_k)
             else:
-                new_key_values = []
-                for layer in past:
-                    items = []
-                    # item is either the key or the value matrix
-                    for item in layer:
-                        items.append(item.repeat_interleave(top_k, dim=0))
-                    new_key_values.append(tuple(items))
-
-                past = tuple(new_key_values)
-
-            model_kwargs["past_key_values"] = past
+                raise ValueError(
+                    f"Unsupported cache type: {type(outputs['past_key_values'])}. Contrastive search requires "
+                    "dynamic cache, so set `cache_implementation='dynamic'` in the generation config."
+                )
 
         if sequential:
             all_outputs = []
@@ -477,15 +470,10 @@ def _contrastive_search(
             ):
                 next_past_key_values.batch_select_indices(augmented_idx)
             else:
-                new_key_values = []
-                for layer in next_past_key_values:
-                    items = []
-                    # item is either the key or the value matrix
-                    for item in layer:
-                        items.append(item[augmented_idx, ...])
-                    new_key_values.append(tuple(items))
-
-                next_past_key_values = tuple(new_key_values)
+                raise ValueError(
+                    f"Unsupported cache type: {type(outputs['past_key_values'])}. Contrastive search requires "
+                    "dynamic cache, so set `cache_implementation='dynamic'` in the generation config."
+                )
 
             logit_for_next_step = torch.stack(torch.split(logits, top_k))[
                 range(batch_size), selected_idx, :
@@ -569,13 +557,10 @@ def _contrastive_search(
                 ):
                     model_kwargs["past_key_values"].crop(-1)
                 else:
-                    past_key_values = []
-                    for layer in model_kwargs["past_key_values"]:
-                        layer_past_key_values = []
-                        for item in layer:
-                            layer_past_key_values.append(item[..., :-1, :])
-                        past_key_values.append(tuple(layer_past_key_values))
-                    model_kwargs["past_key_values"] = tuple(past_key_values)
+                    raise ValueError(
+                        f"Unsupported cache type: {type(outputs['past_key_values'])}. Contrastive search requires "
+                        "dynamic cache, so set `cache_implementation='dynamic'` in the generation config."
+                    )
 
     if model.config.is_encoder_decoder:
         return GenerateEncoderDecoderOutput(
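A few notes on the cache calls the new code standardizes on. First, `batch_repeat_interleave`, the branch hunk 1 keeps: each of the `top_k` candidate tokens needs its own copy of the prefix cache. A minimal sketch, assuming a transformers version whose `DynamicCache` exposes `batch_repeat_interleave`; the shapes are illustrative, not taken from this repo:

```python
import torch
from transformers import DynamicCache

batch_size, num_heads, seq_len, head_dim, top_k = 2, 4, 5, 8, 3

# Build a one-layer cache holding a [batch, heads, seq, head_dim] prefix.
cache = DynamicCache()
cache.update(
    torch.randn(batch_size, num_heads, seq_len, head_dim),  # key states
    torch.randn(batch_size, num_heads, seq_len, head_dim),  # value states
    layer_idx=0,
)

# Expand the batch axis so every top_k candidate sees the same prefix:
# the batch dimension grows from batch_size (2) to batch_size * top_k (6).
cache.batch_repeat_interleave(top_k)
```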
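Second, `batch_select_indices`, the branch hunk 2 keeps: after the candidates are scored, one of the `top_k` per batch item survives, and its rows are picked out of the expanded cache. A pure-torch sketch of the index arithmetic; the `augmented_idx` construction mirrors what the surrounding file computes, and the values are illustrative:

```python
import torch

batch_size, top_k = 2, 3
selected_idx = torch.tensor([1, 2])  # winning candidate per batch item

# Map each winner back into the flattened (batch_size * top_k) batch.
augmented_idx = torch.arange(batch_size) * top_k + selected_idx
print(augmented_idx)  # tensor([1, 5])

# next_past_key_values.batch_select_indices(augmented_idx) then keeps
# exactly rows 1 and 5 of every key/value tensor in the cache.
```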
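Third, `crop(-1)`, the branch hunk 3 keeps: contrastive search runs the model one token ahead, so the trailing cache position has to be dropped to stay consistent with the other decoding methods. A sketch, assuming the upstream `DynamicCache.crop` semantics where a negative argument removes that many trailing positions:

```python
import torch
from transformers import DynamicCache

cache = DynamicCache()
cache.update(torch.randn(1, 4, 6, 8), torch.randn(1, 4, 6, 8), layer_idx=0)

cache.crop(-1)  # drop the look-ahead token from the cache
print(cache.get_seq_length())  # 5
```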
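Finally, the caller-side remedy all three error messages point at. A hedged usage sketch: the checkpoint is a placeholder, and passing `cache_implementation="dynamic"` to `generate()` assumes a transformers version that accepts it, as the error text itself suggests. On recent versions, this repository's override would be routed in via `model.generate(..., custom_generate="<this repo id>", trust_remote_code=True)`; the repo id is left as a placeholder here.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("The quick brown fox", return_tensors="pt")
out = model.generate(
    **inputs,
    penalty_alpha=0.6,               # penalty_alpha + top_k > 1 selects contrastive search
    top_k=4,
    cache_implementation="dynamic",  # the cache type the new checks require
    max_new_tokens=20,
)
print(tok.decode(out[0], skip_special_tokens=True))
```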