cyrilvallez HF Staff committed on
Commit
30636bc
·
verified ·
1 Parent(s): a4fcf5b

Update custom_generate/generate.py

Browse files
Files changed (1) hide show
  1. custom_generate/generate.py +19 -34
custom_generate/generate.py CHANGED
@@ -325,25 +325,18 @@ def _contrastive_search(
325
 
326
  if not sequential:
327
  # Replicates the new past_key_values to match the `top_k` candidates
328
- past = model_kwargs["past_key_values"]
329
- # If it is a static cache, modify it in-place layer after layer to save memory
330
- if isinstance(past, DynamicCache) or (
331
- isinstance(past, EncoderDecoderCache)
332
- and isinstance(past.self_attention_cache, DynamicCache)
333
- ):
334
- past.batch_repeat_interleave(top_k)
335
  else:
336
- new_key_values = []
337
- for layer in past:
338
- items = []
339
- # item is either the key or the value matrix
340
- for item in layer:
341
- items.append(item.repeat_interleave(top_k, dim=0))
342
- new_key_values.append(tuple(items))
343
-
344
- past = tuple(new_key_values)
345
-
346
- model_kwargs["past_key_values"] = past
347
 
348
  if sequential:
349
  all_outputs = []
@@ -477,15 +470,10 @@ def _contrastive_search(
477
  ):
478
  next_past_key_values.batch_select_indices(augmented_idx)
479
  else:
480
- new_key_values = []
481
- for layer in next_past_key_values:
482
- items = []
483
- # item is either the key or the value matrix
484
- for item in layer:
485
- items.append(item[augmented_idx, ...])
486
- new_key_values.append(tuple(items))
487
-
488
- next_past_key_values = tuple(new_key_values)
489
 
490
  logit_for_next_step = torch.stack(torch.split(logits, top_k))[
491
  range(batch_size), selected_idx, :
@@ -569,13 +557,10 @@ def _contrastive_search(
569
  ):
570
  model_kwargs["past_key_values"].crop(-1)
571
  else:
572
- past_key_values = []
573
- for layer in model_kwargs["past_key_values"]:
574
- layer_past_key_values = []
575
- for item in layer:
576
- layer_past_key_values.append(item[..., :-1, :])
577
- past_key_values.append(tuple(layer_past_key_values))
578
- model_kwargs["past_key_values"] = tuple(past_key_values)
579
 
580
  if model.config.is_encoder_decoder:
581
  return GenerateEncoderDecoderOutput(
 
325
 
326
  if not sequential:
327
  # Replicates the new past_key_values to match the `top_k` candidates
328
+ if isinstance(outputs["past_key_values"], DynamicCache) or (
329
+ isinstance(outputs["past_key_values"], EncoderDecoderCache)
330
+ and isinstance(
331
+ outputs["past_key_values"].self_attention_cache, DynamicCache
332
+ )
333
+ ):
334
+ model_kwargs["past_key_values"] = model_kwargs["past_key_values"].batch_repeat_interleave(top_k)
335
  else:
336
+ raise ValueError(
337
+ f"Unsupported cache type: {type(outputs['past_key_values'])}. Contrastive search requires "
338
+ "dynamic cache, so set `cache_implementation='dynamic'` in the generation config."
339
+ )
 
 
 
 
 
 
 
340
 
341
  if sequential:
342
  all_outputs = []
 
470
  ):
471
  next_past_key_values.batch_select_indices(augmented_idx)
472
  else:
473
+ raise ValueError(
474
+ f"Unsupported cache type: {type(outputs['past_key_values'])}. Contrastive search requires "
475
+ "dynamic cache, so set `cache_implementation='dynamic'` in the generation config."
476
+ )
 
 
 
 
 
477
 
478
  logit_for_next_step = torch.stack(torch.split(logits, top_k))[
479
  range(batch_size), selected_idx, :
 
557
  ):
558
  model_kwargs["past_key_values"].crop(-1)
559
  else:
560
+ raise ValueError(
561
+ f"Unsupported cache type: {type(outputs['past_key_values'])}. Contrastive search requires "
562
+ "dynamic cache, so set `cache_implementation='dynamic'` in the generation config."
563
+ )
 
 
 
564
 
565
  if model.config.is_encoder_decoder:
566
  return GenerateEncoderDecoderOutput(