Aktraiser commited on
Commit
c61068e
·
verified ·
1 Parent(s): 2a70625

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +23 -17
handler.py CHANGED
@@ -22,30 +22,36 @@ class EndpointHandler:
22
  def __call__(self, data):
23
  # Extraire le texte d'entrée
24
  if isinstance(data, dict):
25
- text = data.pop("inputs", "")
26
  else:
27
  text = data
28
-
29
- # Paramètres de génération
30
- params = {
31
- "max_new_tokens": data.get("max_new_tokens", 512),
32
- "temperature": data.get("temperature", 0.7),
33
- "top_p": data.get("top_p", 0.95),
34
- "repetition_penalty": data.get("repetition_penalty", 1.15),
35
- "do_sample": data.get("do_sample", True)
 
 
36
  }
37
 
 
 
 
 
38
  try:
39
- # Générer le texte
40
- result = self.pipeline(
41
  text,
42
- **params
43
  )
44
-
45
  # Formater la sortie
46
- if isinstance(result, list):
47
- return {"generated_text": result[0]["generated_text"]}
48
- return {"generated_text": result["generated_text"]}
49
-
50
  except Exception as e:
51
  return {"error": str(e)}
 
22
  def __call__(self, data):
23
  # Extraire le texte d'entrée
24
  if isinstance(data, dict):
25
+ text = data.get("inputs", "")
26
  else:
27
  text = data
28
+
29
+ # Paramètres de génération par défaut
30
+ generation_kwargs = {
31
+ "max_new_tokens": 512,
32
+ "temperature": 0.7,
33
+ "top_p": 0.95,
34
+ "repetition_penalty": 1.15,
35
+ "do_sample": True,
36
+ "pad_token_id": self.tokenizer.pad_token_id,
37
+ "eos_token_id": self.tokenizer.eos_token_id,
38
  }
39
 
40
+ # Mettre à jour avec les paramètres de la requête si fournis
41
+ if isinstance(data, dict) and "parameters" in data:
42
+ generation_kwargs.update(data["parameters"])
43
+
44
  try:
45
+ # Générer la réponse
46
+ outputs = self.pipeline(
47
  text,
48
+ **generation_kwargs
49
  )
50
+
51
  # Formater la sortie
52
+ if isinstance(outputs, list):
53
+ return {"generated_text": outputs[0]["generated_text"]}
54
+ return {"generated_text": outputs["generated_text"]}
55
+
56
  except Exception as e:
57
  return {"error": str(e)}