# app.py
# Cascade chatbot for a Hugging Face Space / local execution
# - Llama 3.1 (input)
# - FLAN-T5 (rewriting)
# - BART (3-sentence summary)
#
# Requirements (on the Space): set HF_TOKEN in the Secrets.
# Optional environment variables to swap models:
# - LLAMA_MODEL (default: meta-llama/Llama-3.1-8B-Instruct)
# - AUX1_MODEL (default: google/flan-t5-large)
# - AUX2_MODEL (default: facebook/bart-large-cnn)
#
# Usage: python app.py
# Recommended: a requirements.txt with gradio, huggingface-hub, transformers, accelerate, etc.
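#
# Example requirements.txt (a minimal sketch; versions deliberately unpinned, pin as your Space requires):
#   gradio
#   huggingface-hub
#   transformers
#   accelerate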
import os
import traceback
import logging
from typing import List, Dict, Any, Tuple
import gradio as gr
from huggingface_hub import InferenceClient
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("cascade_chatbot")
HF_TOKEN = os.environ.get("HF_TOKEN")
DEFAULT_LLAMA_MODEL = os.environ.get("LLAMA_MODEL", "meta-llama/Llama-3.1-8B-Instruct")
DEFAULT_AUX1 = os.environ.get("AUX1_MODEL", "google/flan-t5-large")
DEFAULT_AUX2 = os.environ.get("AUX2_MODEL", "facebook/bart-large-cnn")
if not HF_TOKEN:
logger.warning("HF_TOKEN não encontrado nas variáveis de ambiente. Configure nos Secrets do Space ou no ambiente local.")
# -------------------------
# Initialize the HF clients
# -------------------------
try:
client_main = InferenceClient(token=HF_TOKEN, model=DEFAULT_LLAMA_MODEL)
client_aux1 = InferenceClient(token=HF_TOKEN, model=DEFAULT_AUX1)
client_aux2 = InferenceClient(token=HF_TOKEN, model=DEFAULT_AUX2)
except Exception:
logger.exception("Falha ao inicializar InferenceClient(s). Verifique HF_TOKEN e nomes dos modelos.")
# Criar objetos None para evitar crash imediato; erros aparecerão ao tentar usar
client_main = None
client_aux1 = None
client_aux2 = None
# -------------------------
# Helpers
# -------------------------
def _messages_to_prompt(messages: List[Dict[str, str]]) -> str:
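    """Flatten a chat-style message list into a single plain-text prompt.

    Example: [{"role": "user", "content": "Hi"}] becomes:
        USER: Hi
        ASSISTANT:
    """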
lines = []
for m in messages:
role = m.get("role", "user")
content = m.get("content", "")
lines.append(f"{role.upper()}: {content}")
lines.append("ASSISTANT:")
return "\n".join(lines)
def _extract_text_from_response(obj: Any) -> str:
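    """Best-effort extraction of the generated text from the many response shapes the
    Inference API can return (dataclasses, OpenAI-style dicts with "choices", objects with
    "generations", plain dicts or strings). Falls back to str(obj) when nothing matches.
    """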
if obj is None:
return ""
    # Common attributes exposed directly on the response object
for attr in ("content", "text", "generated_text", "generation_text"):
if hasattr(obj, attr):
try:
v = getattr(obj, attr)
if isinstance(v, str):
return v
return str(v)
except Exception:
pass
try:
choices = None
if hasattr(obj, "choices"):
choices = obj.choices
elif isinstance(obj, dict) and "choices" in obj:
choices = obj["choices"]
if choices:
first = choices[0]
if isinstance(first, dict):
if "message" in first and isinstance(first["message"], dict) and "content" in first["message"]:
return first["message"]["content"]
if "text" in first:
return first["text"]
if "content" in first:
return first["content"]
if hasattr(first, "message"):
msg = first.message
if isinstance(msg, dict) and "content" in msg:
return msg["content"]
if hasattr(first, "text"):
return first.text
except Exception:
pass
try:
if hasattr(obj, "generations") and len(obj.generations) > 0:
g = obj.generations[0]
if isinstance(g, dict) and "text" in g:
return g["text"]
if hasattr(g, "text"):
return g.text
except Exception:
pass
try:
if isinstance(obj, dict):
for k in ("text", "content", "generated_text"):
if k in obj and isinstance(obj[k], str):
return obj[k]
except Exception:
pass
try:
return str(obj)
except Exception:
return ""
# -------------------------
# Robust calls to the InferenceClient
# -------------------------
def call_model_with_messages(client: InferenceClient, messages: List[Dict[str, str]],
max_new_tokens: int = 512, temperature: float = 0.7, top_p: float = 0.95) -> Any:
"""
    Tries multiple call signatures (chat_completion, client.chat, text_generation, etc.)
    and logs full exceptions for diagnosis.
"""
def try_call(method, /, *pos_args, **kw_args):
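        """Call `method` with the given arguments; log and swallow any exception,
        returning None so the caller can try the next signature."""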
try:
            # Do not dump the full messages list into the log; summarize it
safe_kw = {k: ("[MESSAGES]" if k == "messages" else v) for k, v in kw_args.items()}
logger.info("Tentando %s pos=%s kwargs=%s", getattr(method, "__name__", str(method)), pos_args, safe_kw)
return method(*pos_args, **kw_args)
except Exception:
logger.exception("Falha ao chamar %s", getattr(method, "__name__", str(method)))
return None
    # Try to determine the model name
model_name = getattr(client, "model", None) or DEFAULT_LLAMA_MODEL
# 1) chat_completion
try:
cc = getattr(client, "chat_completion", None)
if cc:
            # a) current huggingface_hub signature: chat_completion(messages=..., model=..., max_tokens=...)
            res = try_call(cc, messages=messages, model=model_name, max_tokens=max_new_tokens, temperature=temperature, top_p=top_p)
            if res is not None:
                return res
            # b) older/alternative signature that uses max_new_tokens instead of max_tokens
            res = try_call(cc, messages=messages, model=model_name, max_new_tokens=max_new_tokens, temperature=temperature)
            if res is not None:
                return res
            # c) OpenAI-style cc.create(...), if exposed
            if hasattr(cc, "create"):
                res = try_call(cc.create, model=model_name, messages=messages, max_tokens=max_new_tokens, temperature=temperature)
                if res is not None:
                    return res
# d) positional
res = try_call(cc, messages)
if res is not None:
return res
except Exception:
logger.exception("Erro no bloco chat_completion")
# 2) client.chat namespace
try:
chat_ns = getattr(client, "chat", None)
if chat_ns:
if hasattr(chat_ns, "create"):
res = try_call(chat_ns.create, model=model_name, messages=messages, max_new_tokens=max_new_tokens, temperature=temperature)
if res is not None:
return res
if hasattr(chat_ns, "chat_completion") and hasattr(chat_ns.chat_completion, "create"):
res = try_call(chat_ns.chat_completion.create, model=model_name, messages=messages, max_new_tokens=max_new_tokens, temperature=temperature)
if res is not None:
return res
res = try_call(chat_ns, model_name, messages)
if res is not None:
return res
except Exception:
logger.exception("Erro no bloco chat namespace")
# 3) text_generation
prompt = _messages_to_prompt(messages)
try:
if hasattr(client, "text_generation"):
res = try_call(client.text_generation, prompt=prompt, max_new_tokens=max_new_tokens, temperature=temperature)
if res is not None:
return res
if hasattr(client, "generate") and callable(client.generate):
res = try_call(client.generate, prompt=prompt, max_new_tokens=max_new_tokens)
if res is not None:
return res
except Exception:
logger.exception("Erro no bloco text_generation/generate")
    # 4) last resort: probe any method whose name suggests it can generate text
candidate_methods = [m for m in dir(client) if any(k in m for k in ("create", "generate", "complete", "run"))]
for name in candidate_methods:
try:
method = getattr(client, name)
if callable(method):
res = try_call(method, messages=messages)
if res is not None:
return res
res = try_call(method, prompt)
if res is not None:
return res
res = try_call(method, messages)
if res is not None:
return res
except Exception:
logger.exception("Erro testando candidato %s", name)
    # every attempt failed; raise with debug information
debug = {"available_attrs": dir(client), "messages_sample": messages[:3]}
logger.error("Todas as tentativas falharam. Debug: %s", debug)
raise RuntimeError(f"Não foi possível chamar o cliente HF com as assinaturas testadas. Debug: {debug}")
# -------------------------
# Pipeline: Llama -> FLAN -> BART
# -------------------------
def pipeline_cascade(user_message: str, system_message: str,
max_tokens: int, temperature: float, top_p: float) -> Tuple[str, List[str]]:
"""
    Runs the cascade: Llama (client_main) -> FLAN (client_aux1) -> BART (client_aux2).
    Returns the final text and a log of the steps taken.
"""
logs = []
    # Build the chat messages
messages = [{"role": "system", "content": system_message or ""}, {"role": "user", "content": user_message}]
try:
logs.append("1) Chamando Llama (entrada)")
response_main_obj = call_model_with_messages(client_main, messages, max_new_tokens=max_tokens, temperature=temperature, top_p=top_p)
response_main = _extract_text_from_response(response_main_obj)
logs.append(f"-> Llama respondeu (resumo): {response_main[:300]}")
# Aux1: FLAN-T5 - reformular
logs.append("2) Chamando FLAN-T5 (reformular)")
prompt_aux1 = f"Reformule este texto de forma clara e concisa:\n{response_main}"
try:
if client_aux1 and hasattr(client_aux1, "text_generation"):
res_a1 = client_aux1.text_generation(prompt=prompt_aux1, max_new_tokens=max(128, max_tokens // 4))
elif client_aux1 and hasattr(client_aux1, "completions") and hasattr(client_aux1.completions, "create"):
res_a1 = client_aux1.completions.create(prompt=prompt_aux1, max_new_tokens=max(128, max_tokens // 4))
else:
res_a1 = None
response_aux1 = _extract_text_from_response(res_a1) if res_a1 is not None else response_main
logs.append(f"-> FLAN-T5 respondeu (resumo): {response_aux1[:300]}")
except Exception:
logs.append("FLAN-T5 falhou; usando resposta do Llama")
response_aux1 = response_main
        # Aux2: BART - 3-sentence summary
        logs.append("3) Calling BART (3-sentence summary)")
        prompt_aux2 = f"Summarize this text in 3 sentences:\n{response_aux1}"
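        # Serving note (assumption): facebook/bart-large-cnn is usually exposed as a summarization
        # model, so client_aux2.summarization(response_aux1) may succeed where text_generation is
        # rejected by the Inference API.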
try:
if client_aux2 and hasattr(client_aux2, "text_generation"):
res_a2 = client_aux2.text_generation(prompt=prompt_aux2, max_new_tokens=150)
elif client_aux2 and hasattr(client_aux2, "completions") and hasattr(client_aux2.completions, "create"):
res_a2 = client_aux2.completions.create(prompt=prompt_aux2, max_new_tokens=150)
else:
res_a2 = None
response_aux2 = _extract_text_from_response(res_a2) if res_a2 is not None else response_aux1
logs.append(f"-> BART respondeu (resumo): {response_aux2[:300]}")
except Exception:
logs.append("BART falhou; usando resposta do passo anterior")
response_aux2 = response_aux1
except Exception as e:
tb = traceback.format_exc(limit=5)
logger.exception("Erro pipeline principal: %s", e)
response_aux2 = f"Erro ao gerar resposta: {e}\n\nTraceback (curto):\n{tb}"
logs.append("Erro no pipeline: " + str(e))
return response_aux2, logs
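# Example usage (assumes a valid HF_TOKEN and access to the configured models):
#   final_text, steps = pipeline_cascade("What is overfitting?", "You are a helpful assistant.", 512, 0.7, 0.95)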
# -------------------------
# Gradio App
# -------------------------
with gr.Blocks(title="Cascade Chatbot - Llama + FLAN + BART") as demo:
gr.Markdown("## Trabalho Acadêmico FMU - Chatbot em Cascata\n"
"Fluxo: **Llama (entrada)** → **FLAN-T5 (reformulação)** → **BART(resumo)**\n\n"
"Disciplina: INTELIGÊNCIA ARTIFICIAL E APRENDIZADO DE MÁQUINA")
with gr.Row():
with gr.Column(scale=2):
            system_message = gr.Textbox(value="You are a rational and cheerful chatbot.",
label="System Message", lines=2)
            chatbot = gr.Chatbot(label="Chat", type="messages")  # renders the {"role", "content"} history below
            user_input = gr.Textbox(label="Type your message", placeholder="Type here...")
max_tokens = gr.Slider(50, 2048, value=512, step=50, label="Max Tokens")
temperature = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Temperature")
top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
history = gr.State([])
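            # The history is a list of {"role": ..., "content": ...} dicts, matching the
            # Chatbot(type="messages") component above.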
def submit_handler(msg, history, system_message, max_tokens, temperature, top_p):
                # run the pipeline and update the chat history
out_text, logs = pipeline_cascade(msg, system_message, int(max_tokens), float(temperature), float(top_p))
history.append({"role": "user", "content": msg})
history.append({"role": "assistant", "content": out_text})
                # also write the pipeline logs to the console (useful for debugging)
logger.info("Pipeline logs:\n%s", "\n".join(logs))
return history, history
user_input.submit(submit_handler,
inputs=[user_input, history, system_message, max_tokens, temperature, top_p],
outputs=[chatbot, history])
            btn_send = gr.Button("Send")
btn_send.click(submit_handler,
inputs=[user_input, history, system_message, max_tokens, temperature, top_p],
outputs=[chatbot, history])
with gr.Column(scale=1):
gr.Markdown("### Informações sobre o Projeto\n"
"Painel feito para descrever as **configurações**, **testar a geração** e sobre os **envolvidos**:")
model_info_md = f"""
**Models used:**
- Llama (input): `{DEFAULT_LLAMA_MODEL}`
- Aux 1 (rewriting): `{DEFAULT_AUX1}`
- Aux 2 (summary): `{DEFAULT_AUX2}`

**How they are configured:**
- Each model is instantiated via `InferenceClient(token=HF_TOKEN, model=<model_name>)`.
- Preferred calls:
  - For chat: `client.chat_completion(messages=..., model=...)` (when available)
  - Fallback: `client.text_generation(prompt=...)`
- User-controlled inference settings: `max_tokens`, `temperature`, `top_p`.
- Diagnostic logs are written (useful when there are signature/permission errors).
"""
gr.Markdown(model_info_md)
            # Self-test: runs the cascade on a few predefined messages and shows the result
            test_output = gr.Textbox(label="Self-Test Result", lines=12, interactive=False)
def run_self_test(system_message, max_tokens, temperature, top_p):
msgs = [
"Explique resumidamente o que é a técnica de regressão linear.",
"Resuma em 1 frase as vantagens de usar validação cruzada.",
"Como posso autenticar usuários em uma aplicação web?"
]
accumulated = []
for m in msgs:
out, logs = pipeline_cascade(m, system_message, int(max_tokens), float(temperature), float(top_p))
accumulated.append("INPUT: " + m)
accumulated.append("OUTPUT: " + out)
accumulated.append("LOGS: " + " | ".join(logs))
accumulated.append("-" * 40)
return "\n".join(accumulated)
btn_test = gr.Button("Run self-test")
btn_test.click(run_self_test, inputs=[system_message, max_tokens, temperature, top_p], outputs=[test_output])
gr.Markdown(
"### Disciplina: INTELIGÊNCIA ARTIFICIAL E APRENDIZADO DE MÁQUINA\n"
"- Trabalho N2\n"
"- Turma Noturna de Bacharelado em Ciências da Computação 2025.\n"
"- Integrantes:\n "
"- Lucas Antonini - 1722631\n "
"- Carlos Eduardo da Silva - 1961011\n "
"- Felipe Rios Amaral - 1847080 \n"
"- Kawrê Britto de Oliveira - 2260931\n"
"- Miguel Putini Alfano - 2879347 ")
if __name__ == "__main__":
demo.launch()
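    # For local testing, demo.launch(share=True) additionally creates a temporary public URL.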