Spaces:
Running
Running
| import gradio as gr | |
| from functools import lru_cache | |
| import random | |
| import requests | |
| import logging | |
| import arena_config | |
| import plotly.graph_objects as go | |
| from typing import Dict | |
| from leaderboard import ( | |
| get_current_leaderboard, | |
| update_leaderboard, | |
| start_backup_thread, | |
| get_leaderboard, | |
| get_elo_leaderboard, | |
| ensure_elo_ratings_initialized | |
| ) | |
| import sys | |
# Initialize logging for errors only: the root logger is configured at ERROR
# level, so the logger.error(...) calls in this module are emitted while
# lower-level messages are suppressed.
logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)
# Start the backup thread (imported from the leaderboard module). This runs at
# import time, before the Gradio UI below is built.
# NOTE(review): what is backed up and how often is defined in leaderboard.py.
start_backup_thread()
def get_available_models():
    """Return the identifiers of every approved model.

    Each entry of ``arena_config.APPROVED_MODELS`` is indexed and its first
    element is taken as the model identifier; order is preserved.
    """
    model_ids = []
    for entry in arena_config.APPROVED_MODELS:
        model_ids.append(entry[0])
    return model_ids
def call_ollama_api(model, prompt):
    """Send a single-turn chat request for *model* and return the reply text.

    Posts an OpenAI-style ``/v1/chat/completions`` request to
    ``arena_config.API_URL`` (with ``arena_config.HEADERS``) and extracts
    ``choices[0].message.content`` from the JSON response. No caching is
    performed: every call issues a new HTTP request.

    Args:
        model: Model identifier sent as the ``model`` field.
        prompt: Content of the single user message.

    Returns:
        The assistant's reply, or the fixed string
        ``"Error: Unable to get response from the model."`` when the request
        fails or the response body does not have the expected shape. These
        failures are logged rather than raised so the UI always has text to
        render.
    """
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
    }
    try:
        response = requests.post(
            f"{arena_config.API_URL}/v1/chat/completions",
            headers=arena_config.HEADERS,
            json=payload,
            timeout=100,
        )
        response.raise_for_status()
        data = response.json()
        return data["choices"][0]["message"]["content"]
    except requests.exceptions.RequestException as e:
        logger.error("Error calling Ollama API for model %s: %s", model, e)
    except (ValueError, KeyError, IndexError, TypeError) as e:
        # ValueError: body is not valid JSON (older requests versions raise a
        # plain ValueError). Key/Index/TypeError: JSON without the expected
        # choices -> message -> content structure. Previously these escaped
        # and crashed the Gradio event handler.
        logger.error("Unexpected response from Ollama API for model %s: %s", model, e)
    return "Error: Unable to get response from the model."
def generate_responses(prompt):
    """Pick two distinct approved models at random and query both with *prompt*.

    Returns:
        ``(response_a, response_b, model_a, model_b)``. When fewer than two
        models are configured, both responses are an error message and both
        model slots are ``None``.
    """
    candidates = get_available_models()
    if len(candidates) >= 2:
        first, second = random.sample(candidates, 2)
        # Tuple elements evaluate left to right: the first model is queried first.
        return (
            call_ollama_api(first, prompt),
            call_ollama_api(second, prompt),
            first,
            second,
        )
    not_enough = "Error: Not enough models available"
    return not_enough, not_enough, None, None
def battle_arena(prompt):
    """Run one blind battle: query two random models and lay them out side by side.

    The two responses are randomly assigned to the left/right panels so the
    position does not reveal which model answered. Each panel gets a random
    nickname from ``arena_config.model_nicknames``. The two nicknames are
    distinct whenever the pool holds at least two names, so the vote
    buttons never read the same.

    Returns:
        An 11-tuple matching the ``submit_btn.click`` outputs:
        left chat, right chat, left model id, right model id, left chatbot
        update, right chatbot update, left/right vote button updates, tie
        button update, the prompt (stored as the previous prompt) and the
        reset tie count (0).
    """
    response_a, response_b, model_a, model_b = generate_responses(prompt)

    nicknames = arena_config.model_nicknames
    if len(nicknames) >= 2:
        # Sample without replacement so the two panels get different labels.
        nickname_left, nickname_right = random.sample(nicknames, 2)
    else:
        nickname_left = nickname_right = random.choice(nicknames)

    # Format responses for gr.Chatbot, including the user's prompt.
    chat_a = [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": response_a},
    ]
    chat_b = [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": response_b},
    ]

    # Randomly decide which model appears on the left.
    if random.choice([True, False]):
        left_chat, right_chat, left_id, right_id = chat_a, chat_b, model_a, model_b
    else:
        left_chat, right_chat, left_id, right_id = chat_b, chat_a, model_b, model_a

    return (
        left_chat, right_chat, left_id, right_id,
        gr.update(label=nickname_left, value=left_chat),
        gr.update(label=nickname_right, value=right_chat),
        gr.update(interactive=True, value=f"Vote for {nickname_left}"),
        gr.update(interactive=True, value=f"Vote for {nickname_right}"),
        # Enable and show the Tie button, and restore its label:
        # continue_conversation replaces it with "Max ties reached. Please vote!"
        # once the tie limit is hit, and that text would otherwise carry over.
        gr.update(interactive=True, visible=True, value="Tie π Continue with a new prompt"),
        prompt,  # Return the original prompt
        0,  # Initialize tie count
    )
def record_vote(prompt, left_response, right_response, left_model, right_model, choice):
    """Record the user's verdict for the current battle and reveal the models.

    Args:
        prompt: Current prompt text (passed in by the UI; not used here).
        left_response: Chat history shown in the left panel.
        right_response: Chat history shown in the right panel.
        left_model: Model identifier behind the left panel.
        right_model: Model identifier behind the right panel.
        choice: ``"Left is better"`` makes the left model the winner; any
            other value makes the right model the winner.

    Returns:
        An 8-tuple matching the vote buttons' ``outputs``: result message,
        leaderboard HTML, ELO leaderboard HTML, left/right/tie button
        updates, model-names row update and leaderboard chart.
    """
    # Check if outputs are generated
    if not left_response or not right_response or not left_model or not right_model:
        # Return exactly one value per output component (8) so each update
        # lands on the intended component. The result box starts hidden, so
        # it must be made visible or the message would never be seen.
        return (
            gr.update(value="Please generate responses before voting.", visible=True),
            gr.update(),  # leaderboard unchanged
            gr.update(),  # ELO leaderboard unchanged
            gr.update(interactive=False),  # left vote button
            gr.update(interactive=False),  # right vote button
            gr.update(interactive=False),  # tie button
            gr.update(visible=False),  # keep model names hidden
            gr.update(),  # chart unchanged
        )

    left_won = choice == "Left is better"
    winner, loser = (left_model, right_model) if left_won else (right_model, left_model)

    # Update the leaderboard
    update_leaderboard(winner, loser)

    # NOTE(review): the emoji in this message look mis-decoded, as do those in
    # other UI labels in this file; the intended characters cannot be
    # recovered from here.
    result_message = f"""
π Vote recorded! You're awesome! π
π΅ In the left corner: {get_human_readable_name(left_model)}
π΄ In the right corner: {get_human_readable_name(right_model)}
π And the champion you picked is... {get_human_readable_name(winner)}! π₯
"""
    return (
        gr.update(value=result_message, visible=True),  # Show result text
        get_leaderboard(),  # Update leaderboard
        get_elo_leaderboard(),  # Update ELO leaderboard
        gr.update(interactive=False),  # Disable left vote button
        gr.update(interactive=False),  # Disable right vote button
        gr.update(interactive=False),  # Disable tie button
        gr.update(visible=True),  # Show model names
        get_leaderboard_chart(),  # Update leaderboard chart
    )
def _chart_score(wins, losses):
    """Return the confidence-damped win rate used to rank models on the chart.

    The raw win rate is multiplied by ``1 - 1 / (battles + 1)`` so models with
    few battles are pulled toward 0; a model with no battles scores 0.
    """
    total_battles = wins + losses
    if total_battles <= 0:
        return 0
    win_rate = wins / total_battles
    return win_rate * (1 - 1 / (total_battles + 1))


def get_leaderboard_chart():
    """Build the "Model Performance" Plotly figure from the current battle results.

    Bars show stacked wins and losses per model; a line on a secondary y-axis
    shows each model's ``_chart_score``. Models are ordered by score, then by
    number of battles, both descending.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    battle_results = get_current_leaderboard()

    # Scores are kept in a local mapping instead of being written back into
    # the dict returned by get_current_leaderboard(), so rendering the chart
    # does not alter leaderboard data.
    # NOTE(review): confirm nothing else relied on the injected "score" key.
    scores_by_model = {
        model: _chart_score(results["wins"], results["losses"])
        for model, results in battle_results.items()
    }
    sorted_results = sorted(
        battle_results.items(),
        key=lambda item: (scores_by_model[item[0]], item[1]["wins"] + item[1]["losses"]),
        reverse=True,
    )

    models = [get_human_readable_name(model) for model, _ in sorted_results]
    wins = [results["wins"] for _, results in sorted_results]
    losses = [results["losses"] for _, results in sorted_results]
    scores = [scores_by_model[model] for model, _ in sorted_results]

    fig = go.Figure()
    # Stacked Bar chart for Wins and Losses
    fig.add_trace(go.Bar(
        x=models,
        y=wins,
        name='Wins',
        marker_color='#22577a'
    ))
    fig.add_trace(go.Bar(
        x=models,
        y=losses,
        name='Losses',
        marker_color='#38a3a5'
    ))
    # Line chart for Scores
    fig.add_trace(go.Scatter(
        x=models,
        y=scores,
        name='Score',
        yaxis='y2',
        line=dict(color='#ff7f0e', width=2)
    ))
    # Update layout for full-width, increased height, and secondary y-axis
    fig.update_layout(
        title='Model Performance',
        xaxis_title='Models',
        yaxis_title='Number of Battles',
        yaxis2=dict(
            title='Score',
            overlaying='y',
            side='right'
        ),
        barmode='stack',
        height=800,
        width=1450,
        autosize=True,
        legend=dict(
            orientation='h',
            yanchor='bottom',
            y=1.02,
            xanchor='right',
            x=1
        )
    )
    # The former debug print serialized the whole figure to JSON on every
    # render just to report its size on stdout; it has been removed.
    return fig
def new_battle():
    """Reset the Battle Arena tab for a fresh round.

    Picks new panel nicknames (distinct whenever the nickname pool holds at
    least two names), clears both chats and the prompt, and disables the
    vote buttons. It also hides the tie button and restores its label, and
    hides the result box and model names.

    Returns:
        A 13-tuple matching the ``new_battle_btn.click`` outputs, in order:
        prompt input, left chatbot, right chatbot, left model, right model,
        left/right vote buttons, tie button, result, leaderboard,
        model-names row, leaderboard chart and tie count.
    """
    nicknames = arena_config.model_nicknames
    if len(nicknames) >= 2:
        # Sample without replacement so the two panels get different labels.
        nickname_a, nickname_b = random.sample(nicknames, 2)
    else:
        nickname_a = nickname_b = random.choice(nicknames)
    return (
        "",  # Reset prompt_input
        gr.update(value=[], label=nickname_a),  # Reset left Chatbot
        gr.update(value=[], label=nickname_b),  # Reset right Chatbot
        None,  # Clear left model name
        None,  # Clear right model name
        gr.update(interactive=False, value=f"Vote for {nickname_a}"),
        gr.update(interactive=False, value=f"Vote for {nickname_b}"),
        # Reset the Tie button, including its label, which continue_conversation
        # replaces with "Max ties reached. Please vote!" at the tie limit.
        gr.update(interactive=False, visible=False, value="Tie π Continue with a new prompt"),
        gr.update(value="", visible=False),  # Clear and hide result
        gr.update(),  # Leaderboard unchanged
        gr.update(visible=False),  # Hide model names
        gr.update(),  # Chart unchanged
        0,  # Reset tie_count
    )
def get_human_readable_name(model_name: str) -> str:
    """Map a model identifier to its display name.

    ``arena_config.APPROVED_MODELS`` is read as (identifier, display name)
    pairs; if an identifier repeats, its last pair wins. Identifiers with no
    entry are returned unchanged.
    """
    display_names = {model_id: label for model_id, label in arena_config.APPROVED_MODELS}
    return display_names.get(model_name, model_name)
def random_prompt():
    """Return one example prompt chosen at random from ``arena_config.example_prompts``."""
    prompts = arena_config.example_prompts
    return random.choice(prompts)
def continue_conversation(prompt, left_chat, right_chat, left_model, right_model, previous_prompt, tie_count):
    """Handle a tie by asking both models a follow-up prompt.

    If *prompt* is empty or repeats *previous_prompt*, a random example
    prompt is used instead. The replacement differs from *previous_prompt*
    whenever another example exists. Both answers are appended to new copies
    of the chat histories, and the tie counter is incremented. From the
    third tie on, the tie button is disabled so the user has to vote.

    Returns:
        A 6-tuple matching the ``tie_btn.click`` outputs: left chatbot
        update, right chatbot update, cleared prompt input, tie button
        update, the prompt actually used (stored as the new previous prompt)
        and the updated tie count.
    """
    # Check if the prompt is empty or the same as the previous one
    if not prompt or prompt == previous_prompt:
        examples = list(arena_config.example_prompts)
        fresh = [p for p in examples if p != previous_prompt]
        # Fall back to the full list only if every example equals the previous prompt.
        prompt = random.choice(fresh or examples)

    left_response = call_ollama_api(left_model, prompt)
    right_response = call_ollama_api(right_model, prompt)

    # Build new lists rather than appending to the incoming histories in place;
    # a missing (None) history is treated as empty.
    left_chat = list(left_chat or []) + [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": left_response},
    ]
    right_chat = list(right_chat or []) + [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": right_response},
    ]

    tie_count += 1
    if tie_count < 3:
        tie_button_state = gr.update(interactive=True)
    else:
        tie_button_state = gr.update(interactive=False, value="Max ties reached. Please vote!")

    return (
        gr.update(value=left_chat),
        gr.update(value=right_chat),
        gr.update(value=""),  # Clear the prompt input
        tie_button_state,
        prompt,  # Return the new prompt
        tie_count,
    )
# Initialize Gradio Blocks
# NOTE(review): the nesting below was reconstructed from a copy whose
# indentation had been stripped; confirm it matches the deployed file
# (in particular where the "<br>" spacer sits).
with gr.Blocks(css="""
#dice-button {
min-height: 90px;
font-size: 35px;
}
""") as demo:
    # Page header text comes from arena_config.
    gr.Markdown(arena_config.ARENA_NAME)
    gr.Markdown(arena_config.ARENA_DESCRIPTION)
    # Battle Arena Tab
    with gr.Tab("Battle Arena"):
        with gr.Row():
            prompt_input = gr.Textbox(
                label="Enter your prompt",
                placeholder="Type your prompt here...",
                scale=20
            )
            # Dice button, styled by the #dice-button CSS rule above.
            random_prompt_btn = gr.Button("π²", scale=1, elem_id="dice-button")
        gr.Markdown("<br>")
        # Dice button fills the prompt box with a random example prompt.
        random_prompt_btn.click(
            random_prompt,
            outputs=prompt_input
        )
        submit_btn = gr.Button("Generate Responses", variant="primary")
        with gr.Row():
            # Initial labels are random nicknames; battle_arena and new_battle relabel them.
            left_output = gr.Chatbot(label=random.choice(arena_config.model_nicknames), type="messages")
            right_output = gr.Chatbot(label=random.choice(arena_config.model_nicknames), type="messages")
        with gr.Row():
            # Vote buttons start disabled until responses exist; the tie button starts hidden.
            left_vote_btn = gr.Button(f"Vote for {left_output.label}", interactive=False)
            tie_btn = gr.Button("Tie π Continue with a new prompt", interactive=False, visible=False)
            right_vote_btn = gr.Button(f"Vote for {right_output.label}", interactive=False)
        result = gr.Textbox(label="Result", interactive=False, visible=False)
        # Model identities stay hidden until a vote is recorded (record_vote shows this row).
        with gr.Row(visible=False) as model_names_row:
            left_model = gr.Textbox(label="π΅ Left Model", interactive=False)
            right_model = gr.Textbox(label="π΄ Right Model", interactive=False)
        previous_prompt = gr.State("")  # Last prompt sent; continue_conversation avoids repeating it
        tie_count = gr.State(0)  # Ties in the current battle; battle_arena and new_battle reset it to 0
        new_battle_btn = gr.Button("New Battle")
    # Leaderboard Tab
    with gr.Tab("Leaderboard"):
        leaderboard = gr.HTML(label="Leaderboard")
    # Performance Chart Tab
    with gr.Tab("Performance Chart"):
        leaderboard_chart = gr.Plot(label="Model Performance Chart")
    # ELO Leaderboard Tab
    with gr.Tab("ELO Leaderboard"):
        elo_leaderboard = gr.HTML(label="ELO Leaderboard")
    # Define interactions
    # battle_arena returns 11 values in this order. left_output/right_output are
    # listed twice: first for the raw message lists, then for label+value updates.
    submit_btn.click(
        battle_arena,
        inputs=prompt_input,
        outputs=[left_output, right_output, left_model, right_model,
                 left_output, right_output, left_vote_btn, right_vote_btn,
                 tie_btn, previous_prompt, tie_count]
    )
    # Each lambda appends the verdict as record_vote's `choice` argument;
    # prompt_input is forwarded as record_vote's (unused) `prompt`.
    left_vote_btn.click(
        lambda *args: record_vote(*args, "Left is better"),
        inputs=[prompt_input, left_output, right_output, left_model, right_model],
        outputs=[result, leaderboard, elo_leaderboard, left_vote_btn,
                 right_vote_btn, tie_btn, model_names_row, leaderboard_chart]
    )
    right_vote_btn.click(
        lambda *args: record_vote(*args, "Right is better"),
        inputs=[prompt_input, left_output, right_output, left_model, right_model],
        outputs=[result, leaderboard, elo_leaderboard, left_vote_btn,
                 right_vote_btn, tie_btn, model_names_row, leaderboard_chart]
    )
    # A tie asks both models a follow-up prompt and appends the answers to both chats.
    tie_btn.click(
        continue_conversation,
        inputs=[prompt_input, left_output, right_output, left_model, right_model, previous_prompt, tie_count],
        outputs=[left_output, right_output, prompt_input, tie_btn, previous_prompt, tie_count]
    )
    # NOTE(review): previous_prompt is not among these outputs, so it is not reset by New Battle.
    new_battle_btn.click(
        new_battle,
        outputs=[prompt_input, left_output, right_output, left_model,
                 right_model, left_vote_btn, right_vote_btn, tie_btn,
                 result, leaderboard, model_names_row, leaderboard_chart, tie_count]
    )
    # Update leaderboard and chart on launch (each page load fills the three tabs).
    demo.load(get_leaderboard, outputs=leaderboard)
    demo.load(get_elo_leaderboard, outputs=elo_leaderboard)
    demo.load(get_leaderboard_chart, outputs=leaderboard_chart)
if __name__ == "__main__":
    # Initialize ELO ratings before launching the app.
    # NOTE(review): this runs only when the file is executed as a script; if
    # `demo` is served by importing this module instead, ELO init is skipped.
    ensure_elo_ratings_initialized()
    demo.launch(show_api=False)