habdine commited on
Commit
ac9448d
·
verified ·
1 Parent(s): f14174b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -56,8 +56,8 @@ EXAMPLE_SEQUENCES = [
56
  ],
57
  ]
58
 
59
- MAX_MAX_NEW_TOKENS = 256
60
- DEFAULT_MAX_NEW_TOKENS = 100
61
 
62
 
63
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -90,7 +90,7 @@ def stream_response(
90
  temperature: float = 0.6,
91
  top_p: float = 0.9,
92
  top_k: int = 50,
93
- repetition_penalty: float = 1.2,
94
  ) -> Iterator[str]:
95
 
96
 
@@ -240,24 +240,24 @@ with gr.Blocks(theme=theme, css_paths="style.css", fill_height=True) as demo:
240
  do_sample_checkbox = gr.Checkbox(label="Enable sampling", value=False)
241
  temperature_slider = gr.Slider(
242
  label="Temperature",
243
- minimum=0.1,
244
  maximum=4.0,
245
  step=0.1,
246
- value=0.6,
247
  )
248
  top_p_slider = gr.Slider(
249
  label="Top-p (nucleus sampling)",
250
  minimum=0.05,
251
  maximum=1.0,
252
  step=0.05,
253
- value=0.9,
254
  )
255
  top_k_slider = gr.Slider(
256
  label="Top-k",
257
  minimum=1,
258
  maximum=1000,
259
  step=1,
260
- value=50,
261
  )
262
  repetition_penalty_slider = gr.Slider(
263
  label="Repetition penalty",
 
56
  ],
57
  ]
58
 
59
+ MAX_MAX_NEW_TOKENS = 1024
60
+ DEFAULT_MAX_NEW_TOKENS = 512
61
 
62
 
63
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
90
  temperature: float = 0.6,
91
  top_p: float = 0.9,
92
  top_k: int = 50,
93
+ repetition_penalty: float = 1.0,
94
  ) -> Iterator[str]:
95
 
96
 
 
240
  do_sample_checkbox = gr.Checkbox(label="Enable sampling", value=False)
241
  temperature_slider = gr.Slider(
242
  label="Temperature",
243
+ minimum=0.0,
244
  maximum=4.0,
245
  step=0.1,
246
+ value=0.0,
247
  )
248
  top_p_slider = gr.Slider(
249
  label="Top-p (nucleus sampling)",
250
  minimum=0.05,
251
  maximum=1.0,
252
  step=0.05,
253
+ value=1.0,
254
  )
255
  top_k_slider = gr.Slider(
256
  label="Top-k",
257
  minimum=1,
258
  maximum=1000,
259
  step=1,
260
+ value=1,
261
  )
262
  repetition_penalty_slider = gr.Slider(
263
  label="Repetition penalty",