Label / app.py
bgamazay's picture
Update app.py
bb0fad9 verified
raw
history blame
13.3 kB
import streamlit as st
import pandas as pd
from PIL import Image, ImageDraw, ImageFont
import io
import os
import socket
import calendar
import re
from typing import Optional
from huggingface_hub import hf_hub_download
# =========================
# Hugging Face Space config
# =========================
# Space that hosts the leaderboard data, and the folder inside it where the
# per-task CSV files live.
HF_REPO_ID = "AIEnergyScore/Leaderboard"
HF_REPO_TYPE = "space"
HF_DATA_PREFIX = "data/energy"

# =========================
# Task -> CSV mapping
# =========================
# Display name of each task -> CSV file name under HF_DATA_PREFIX.
# Built from (name, file) pairs; dicts keep insertion order, so iterating
# TASK_TO_CSV yields the tasks in the order listed here.
TASK_TO_CSV = dict(
    [
        ("Text Generation", "text_generation.csv"),
        ("Reasoning", "reasoning.csv"),
        ("Image Generation", "image_generation.csv"),
        ("Text Classification", "text_classification.csv"),
        ("Image Classification", "image_classification.csv"),
        ("Image Captioning", "image_captioning.csv"),
        ("Summarization", "summarization.csv"),
        ("Speech-to-Text (ASR)", "asr.csv"),
        ("Object Detection", "object_detection.csv"),
        ("Question Answering", "question_answering.csv"),
        ("Sentence Similarity", "sentence_similarity.csv"),
    ]
)

# Legacy alias for code that still uses the older name; same dict object.
task_to_file = TASK_TO_CSV
# =========================
# Helpers
# =========================
def read_csv_from_hub(file_name: str) -> pd.DataFrame:
    """
    Load a task CSV, preferring the copy published in the Hugging Face Space.

    The file is fetched from ``<HF_DATA_PREFIX>/<file_name>`` in the Space
    ``HF_REPO_ID`` (revision ``main``). If that fails for any reason (DNS
    failure, network error, missing file, ...), local copies are tried: first
    ``file_name`` relative to the working directory, then the same relative
    path the file has on the Hub (``<HF_DATA_PREFIX>/<file_name>``).

    Args:
        file_name: CSV file name, e.g. ``"text_generation.csv"``.

    Returns:
        The parsed CSV as a DataFrame.

    Raises:
        RuntimeError: if neither the Hub nor any local copy could be read.
            The Hub error is chained as ``__cause__`` and both the Hub and the
            last local error are included in the message.
    """
    hub_path = f"{HF_DATA_PREFIX}/{file_name}"
    try:
        # Cheap DNS probe so an offline environment fails fast before the
        # download call.
        socket.gethostbyname("huggingface.co")
        # `resume_download` is deliberately not passed: huggingface_hub has
        # deprecated it (downloads always resume when possible) and scheduled
        # it for removal, so passing it would eventually break the Hub path.
        local_path = hf_hub_download(
            repo_id=HF_REPO_ID,
            repo_type=HF_REPO_TYPE,
            filename=hub_path,
            revision="main",
        )
        return pd.read_csv(local_path)
    except Exception as hub_error:
        # Best-effort local fallback: any read failure moves on to the next
        # candidate; only when all fail do we raise.
        local_error = None
        for candidate in (file_name, hub_path):
            try:
                return pd.read_csv(candidate)
            except Exception as exc:
                local_error = exc
        raise RuntimeError(
            f"Unable to load '{file_name}' from Hub path '{hub_path}' or locally. "
            f"Original error: {hub_error}; local error: {local_error}"
        ) from hub_error
def format_with_commas(value) -> str:
    """
    Render a number with thousands separators and exactly two decimals.

    >>> format_with_commas(12345.678)
    '12,345.68'

    Anything ``float()`` cannot convert is returned as ``str(value)``.
    """
    try:
        number = float(value)
    except Exception:
        return str(value)
    return format(number, ",.2f")
def _normalize(col: str) -> str:
    """
    Reduce a column header to a compact matching key.

    Surrounding whitespace is trimmed, the text is lower-cased, and every
    character outside ``a``-``z`` / ``0``-``9`` is dropped, so "Test Date",
    "test_date" and "TEST-DATE" all become "testdate".
    """
    cleaned = col.strip()
    cleaned = cleaned.lower()
    return re.sub(r"[^a-z0-9]", repl="", string=cleaned)
def find_test_date_column(df: pd.DataFrame) -> Optional[str]:
    """
    Return the label of the column holding each row's test date, or None.

    Lookup order (first hit wins):
      1) a header equal to "test date", ignoring case and surrounding spaces;
      2) a header whose normalized form (see ``_normalize``) contains
         "testdate" contiguously, e.g. "test_date" or "TestDate";
      3) a header whose normalized form contains both "test" and "date"
         anywhere (looser; this also matches e.g. "latest_update");
      4) positional fallback: the fifth column (spreadsheet column E) when the
         frame has at least five columns.

    Column labels are passed through ``str()`` before matching, so frames
    with non-string labels (e.g. integer positions) do not raise.
    """
    columns = list(df.columns)

    # (1) exact, case-insensitive "test date"
    for c in columns:
        if str(c).strip().lower() == "test date":
            return c

    normalized = [(c, _normalize(str(c))) for c in columns]

    # (2) a contiguous "testdate" is a strong signal; checking it before the
    # loose match keeps e.g. an earlier "latest_update" from shadowing
    # "test_date".
    for c, cn in normalized:
        if "testdate" in cn:
            return c

    # (3) loose match: both words appear somewhere in the header
    for c, cn in normalized:
        if "test" in cn and "date" in cn:
            return c

    # (4) fallback to column E (0-based index 4)
    if len(columns) >= 5:
        return columns[4]
    return None
def month_abbrev_to_full(abbrev: str) -> Optional[str]:
    """
    Translate a month abbreviation into the full month name.

    Only the first three characters (after trimming whitespace) are used and
    they are title-cased, so "oct", "OCT" and "October" all give "October".
    Names come from the ``calendar`` module's month tables.

    Returns None for non-strings, empty strings, or unrecognized prefixes.
    """
    if isinstance(abbrev, str) and abbrev:
        key = abbrev.strip()[:3].title()
        full_by_abbr = {
            calendar.month_abbr[i]: calendar.month_name[i] for i in range(1, 13)
        }
        return full_by_abbr.get(key)
    return None
def render_date_from_test_date(value: str) -> str:
    """
    Turn a short "test date" cell into a "<Month> <YYYY>" string.

    Recognized inputs (surrounding whitespace ignored):
      - "<month> <4-digit year>", e.g. "Oct 2025" -> "October 2025"
      - "<month> <2-digit year>", e.g. "Dec 25"   -> "December 2025"
        (a two-digit year is always prefixed with "20")

    Returns "" for non-string input, any other layout, or an unknown month.
    """
    if not isinstance(value, str):
        return ""
    text = value.strip()
    # (pattern, prefix prepended to the captured year)
    layouts = (
        (r"^([A-Za-z]+)\s+(\d{4})$", ""),
        (r"^([A-Za-z]+)\s+(\d{2})$", "20"),
    )
    for pattern, year_prefix in layouts:
        found = re.match(pattern, text)
        if found is None:
            continue
        month_full = month_abbrev_to_full(found.group(1))
        if not month_full:
            return ""
        return f"{month_full} {year_prefix}{found.group(2)}"
    return ""
def smart_capitalize(text):
    """
    Upper-case the first character of ``text`` unless it already is upper-case.

    The rest of the string is left exactly as given. Falsy input ("" or
    None) is returned unchanged.
    """
    if not text:
        return text
    head, tail = text[0], text[1:]
    if head.isupper():
        return text
    return head.upper() + tail
# =========================
# UI / App
# =========================
def _prepare_task_frame(df: pd.DataFrame, task: str) -> pd.DataFrame:
    """
    Add the columns the label needs to one task's leaderboard CSV.

    Reads the CSV columns "model", "total_gpu_energy" and "energy_score" (a
    KeyError propagates if one is missing) and adds/overwrites "full_model",
    "provider", "model", "energy", "score", "hardware", "task" and "date".
    ``df`` is modified in place and also returned.
    """
    dated = task in {"Text Generation", "Reasoning"}
    # Look up the test-date column on the CSV's own columns, before derived
    # columns are appended; otherwise the "column E" fallback inside
    # find_test_date_column could land on one of the columns added below.
    td_col = find_test_date_column(df) if dated else None

    # Keep the original "Provider/Model" id for the model picker.
    df["full_model"] = df["model"]
    # Split on the first "/" only. reindex guarantees both parts exist even
    # when no id in this CSV contains "/" (expand=True would then return a
    # single column and a two-column assignment would raise ValueError).
    parts = df["model"].str.split(pat="/", n=1, expand=True).reindex(columns=[0, 1])
    df["provider"] = parts[0]
    # Ids without "/" have no model part: use "" rather than None/NaN, which
    # would otherwise be drawn on the label as "None"/"nan".
    df["model"] = parts[1].fillna("")

    # total_gpu_energy is in kWh; the label shows Wh rounded to 2 decimals.
    df["energy"] = (df["total_gpu_energy"] * 1000).round(2)
    # Rows with no energy_score are given score 1.
    df["score"] = df["energy_score"].fillna(1).astype(int)
    # Hardware placeholder: the same value for every row (not read from CSV).
    df["hardware"] = "NVIDIA H100-80GB"
    df["task"] = task

    # DATE: only Text Generation & Reasoning use the CSV's test date; when the
    # column is missing or a cell is unparsable, show "February 2025".
    if not dated:
        df["date"] = ""
    elif td_col is None:
        df["date"] = "February 2025"
    else:
        rendered = df[td_col].apply(render_date_from_test_date)
        df["date"] = rendered.where(rendered != "", "February 2025")
    return df


def _load_task_frames(selected_tasks):
    """
    Load and prepare the CSV of every selected task.

    Problems with a single task (unknown name, unreadable CSV, missing
    column) are reported in the sidebar and that task is skipped.
    Returns the list of prepared DataFrames (possibly empty).
    """
    frames = []
    for task in selected_tasks:
        file_name = TASK_TO_CSV.get(task)
        if not file_name:
            st.sidebar.error(f"Unknown task '{task}'.")
            continue
        try:
            df = read_csv_from_hub(file_name)
        except FileNotFoundError:
            st.sidebar.error(f"Could not find '{file_name}' for task {task}!")
            continue
        except Exception as e:
            st.sidebar.error(f"Error reading '{file_name}' for task {task}: {e}")
            continue
        try:
            frames.append(_prepare_task_frame(df, task))
        except KeyError as e:
            st.sidebar.error(f"'{file_name}' for task {task} is missing expected column {e}.")
    return frames


def main():
    """
    Streamlit entry point for the AI Energy Score label generator.

    Builds the sidebar, loads the CSVs of the selected tasks, lets the user
    pick a scored model, renders that model's label as a preview and offers
    it as a PNG download.
    """
    # Tag styling for the task multiselect
    st.markdown(
        """
        <style>
        .stMultiSelect [data-baseweb="tag"] {
            background-color: #3fa45bff !important;
            color: white !important;
            font-weight: 500;
            border-radius: 5px;
            padding: 5px 10px;
        }
        .stMultiSelect [data-baseweb="tag"]:hover { background-color: #358d4d !important; }
        .stMultiSelect input { color: black !important; }
        </style>
        """,
        unsafe_allow_html=True,
    )

    # Sidebar logo & title
    with st.sidebar:
        col1, col2 = st.columns([1, 5])
        with col1:
            # Context manager closes the logo file once the resized copy exists.
            with Image.open("logo.png") as logo:
                st.image(logo.resize((50, 50)))
        with col2:
            st.markdown(
                """
                <div style="display:flex;align-items:center;gap:10px;margin:0;padding:0;
                font-family:'Inter',sans-serif;font-size:26px;font-weight:500;">
                AI Energy Score
                </div>
                """,
                unsafe_allow_html=True,
            )
    st.sidebar.markdown("<hr style='border: 1px solid gray; margin: 15px 0;'>", unsafe_allow_html=True)
    st.sidebar.write("### Generate Label:")

    # Task order follows TASK_TO_CSV (same names, same order) so the two
    # cannot drift apart.
    task_order = list(TASK_TO_CSV)

    # 1) Select task(s)
    st.sidebar.write("#### 1. Select task(s) to view models")
    # Non-empty label (hidden) instead of "": Streamlit discourages empty
    # widget labels for accessibility reasons.
    selected_tasks = st.sidebar.multiselect(
        "Tasks",
        options=task_order,
        default=["Text Generation"],
        label_visibility="collapsed",
    )

    # Placeholder record used when there is nothing to show; its empty "task"
    # makes create_label_single_pass return the bare background.
    default_model_data = {
        'provider': "AI Provider",
        'model': "Model Name",
        'full_model': "AI Provider/Model Name",
        'date': "",
        'task': "",
        'hardware': "",
        'energy': 0.0,
        'score': 5
    }

    model_data = default_model_data
    if selected_tasks:
        dfs = _load_task_frames(selected_tasks)
        if dfs:
            data_df = pd.concat(dfs, ignore_index=True)
            if not data_df.empty:
                # 2) Select a model
                st.sidebar.write("#### 2. Select a model")
                model_options = data_df["full_model"].unique().tolist()
                selected_model = st.sidebar.selectbox(
                    "Scored Models",
                    model_options,
                    help="Start typing to search for a model"
                )
                # A model present in several selected tasks uses its first row.
                model_data = data_df[data_df["full_model"] == selected_model].iloc[0]

    # 3) Render and download the label
    st.sidebar.write("#### 3. Download the label")
    try:
        score = int(model_data["score"])
        background_path = f"{score}.png"
        background = Image.open(background_path).convert("RGBA")
    except FileNotFoundError:
        st.sidebar.error(f"Could not find background image '{score}.png'. Using default background.")
        background = Image.open("default_background.png").convert("RGBA")
    except ValueError:
        st.sidebar.error(f"Invalid score '{model_data['score']}'. Score must be an integer.")
        return

    final_size = (520, 728)
    generated_label = create_label_single_pass(background, model_data, final_size)
    st.image(generated_label, caption="Generated Label Preview", width=520)

    img_buffer = io.BytesIO()
    generated_label.save(img_buffer, format="PNG")
    img_buffer.seek(0)
    st.sidebar.download_button(
        label="Download",
        data=img_buffer,
        file_name="AIEnergyScore.png",
        mime="image/png"
    )

    # 4) Share
    st.sidebar.write("#### 4. Share your label!")
    st.sidebar.write("[Guidelines](https://huggingface.github.io/AIEnergyScore/#transparency-and-guidelines-for-label-use)")
    st.sidebar.markdown("<hr style='border: 1px solid gray; margin: 15px 0;'>", unsafe_allow_html=True)
    st.sidebar.write("### Key Links")
    st.sidebar.markdown(
        """
        <ul style="margin-top:0;margin-bottom:0;padding-left:20px;">
            <li><a href="https://huggingface.co/spaces/AIEnergyScore/Leaderboard" target="_blank">Leaderboard</a></li>
            <li><a href="https://huggingface.co/spaces/AIEnergyScore/submission_portal" target="_blank">Submission Portal</a></li>
            <li><a href="https://huggingface.github.io/AIEnergyScore/#faq" target="_blank">FAQ</a></li>
            <li><a href="https://huggingface.github.io/AIEnergyScore/#documentation" target="_blank">Documentation</a></li>
        </ul>
        """,
        unsafe_allow_html=True,
    )
def _label_text(value) -> str:
    """Return ``value`` as label text, using "" for missing (None/NaN) values."""
    if value is None:
        return ""
    try:
        if pd.isna(value):
            return ""
    except (TypeError, ValueError):
        # pd.isna on a list-like returns an array whose truth value is
        # ambiguous; such values are simply stringified below.
        pass
    return str(value)


def create_label_single_pass(background_image, model_data, final_size=(520, 728)):
    """
    Draw a model's details onto a resized copy of the label background.

    Args:
        background_image: PIL image used as the label background.
        model_data: dict or pandas Series row with keys "provider", "model",
            "date", "task", "hardware" and "energy".
        final_size: (width, height) of the returned image. The text positions
            below are fixed pixel coordinates for the default 520x728 size.

    Returns:
        The resized background with the text drawn on it. The bare resized
        background is returned when model_data has no task (placeholder
        state) or when a font cannot be loaded (an error is shown then).
    """
    bg_resized = background_image.resize(final_size, Image.Resampling.LANCZOS)

    # If no task is selected (i.e., using default model_data), return background
    if not model_data.get("task"):
        return bg_resized

    draw = ImageDraw.Draw(bg_resized)
    try:
        title_font = ImageFont.truetype("Inter_24pt-Bold.ttf", size=27)
        details_font = ImageFont.truetype("Inter_18pt-Regular.ttf", size=23)
        energy_font = ImageFont.truetype("Inter_18pt-Medium.ttf", size=24)
    except Exception as e:
        st.error(f"Font loading failed: {e}")
        return bg_resized

    title_x, title_y = 33, 150
    details_x, details_y = 480, 256
    energy_x, energy_y = 480, 472  # right-edge anchors

    # Provider and model name, left-aligned on two lines.
    draw.text((title_x, title_y), _label_text(model_data.get('provider')), font=title_font, fill="black")
    draw.text((title_x, title_y + 38), _label_text(model_data.get('model')), font=title_font, fill="black")

    # Details lines (date, task, hardware), right-aligned. anchor="ra" puts
    # the text's right edge exactly on details_x (ascender line at y, same as
    # the default "la" anchor); subtracting the bbox width instead left the
    # right edge off by the glyph's left bearing.
    details_lines = [
        _label_text(model_data.get('date', "")),
        _label_text(model_data.get('task', "")),
        _label_text(model_data.get('hardware', "")),
    ]
    for i, line in enumerate(details_lines):
        draw.text(
            (details_x, details_y + i * 47),
            line,
            font=details_font,
            fill="black",
            anchor="ra",
        )

    # Energy value (two decimals, thousands separators), right-aligned.
    try:
        energy_value = float(model_data.get('energy', 0.0))
    except Exception:
        energy_value = 0.0
    energy_text = format_with_commas(energy_value)
    draw.text((energy_x, energy_y), energy_text, font=energy_font, fill="black", anchor="ra")

    return bg_resized
# Run the app when this file is executed as the main script.
if __name__ == "__main__":
    main()