Update app.py

app.py CHANGED

@@ -7,6 +7,8 @@ tokenizer_splade = None
 model_splade = None
 tokenizer_splade_lexical = None
 model_splade_lexical = None
+tokenizer_splade_doc = None # New tokenizer for SPLADE-v3-Doc
+model_splade_doc = None # New model for SPLADE-v3-Doc
 
 # Load SPLADE v3 model (original)
 try:
@@ -29,6 +31,18 @@ except Exception as e:
     print(f"Error loading SPLADE v3 Lexical model: {e}")
     print(f"Please ensure '{splade_lexical_model_name}' is accessible (check Hugging Face Hub for potential agreements).")
 
+# Load SPLADE v3 Doc model (NEW)
+try:
+    splade_doc_model_name = "naver/splade-v3-doc"
+    tokenizer_splade_doc = AutoTokenizer.from_pretrained(splade_doc_model_name)
+    model_splade_doc = AutoModelForMaskedLM.from_pretrained(splade_doc_model_name)
+    model_splade_doc.eval() # Set to evaluation mode for inference
+    print(f"SPLADE v3 Doc model '{splade_doc_model_name}' loaded successfully!")
+except Exception as e:
+    print(f"Error loading SPLADE v3 Doc model: {e}")
+    print(f"Please ensure '{splade_doc_model_name}' is accessible (check Hugging Face Hub for potential agreements).")
+
+
 # --- Helper function for lexical mask ---
 def create_lexical_bow_mask(input_ids, vocab_size, tokenizer):
     """
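
The new loading block mirrors the two existing ones: tokenizer and model come from the Hub, and eval() disables dropout for inference. As a quick standalone check that the checkpoint loads and emits masked-LM logits over the full vocabulary, a minimal sketch with illustrative names (only "naver/splade-v3-doc" comes from the diff):

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

name = "naver/splade-v3-doc"
tok = AutoTokenizer.from_pretrained(name)
mlm = AutoModelForMaskedLM.from_pretrained(name)
mlm.eval()  # inference mode, as in the diff

enc = tok("sparse retrieval with splade", return_tensors="pt")
with torch.no_grad():
    logits = mlm(**enc).logits
print(logits.shape)  # (1, seq_len, vocab_size): one score per vocab term per position
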
@@ -107,7 +121,7 @@ def get_splade_representation(text):
     return formatted_output
 
 
-def get_splade_lexical_representation(text): # Removed apply_lexical_mask parameter
+def get_splade_lexical_representation(text):
     if tokenizer_splade_lexical is None or model_splade_lexical is None:
         return "SPLADE v3 Lexical model is not loaded. Please check the console for loading errors."
 
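
Both the lexical function above and the new doc function below call create_lexical_bow_mask(input_ids, vocab_size, tokenizer), whose body sits outside the changed hunks. A plausible sketch consistent with its call sites (it must return a tensor the caller can .squeeze() and multiply element-wise into a (vocab_size,) SPLADE vector); every detail here is an assumption, not the code in app.py:

import torch

def create_lexical_bow_mask(input_ids, vocab_size, tokenizer):
    # Assumed implementation: a binary bag-of-words over the input's own token ids,
    # so multiplying by it strips every expansion term the model added.
    mask = torch.zeros(1, vocab_size, device=input_ids.device)
    special = set(tokenizer.all_special_ids)  # drop [CLS], [SEP], [PAD], ...
    kept = [i for i in input_ids[0].tolist() if i not in special]
    mask[0, kept] = 1.0
    return mask  # shape (1, vocab_size); caller squeezes before multiplying
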
@@ -167,12 +181,74 @@ def get_splade_lexical_representation(text): # Removed apply_lexical_mask parameter
     return formatted_output
 
 
+# NEW: Function for SPLADE-v3-Doc representation
+def get_splade_doc_representation(text, apply_lexical_mask: bool):
+    if tokenizer_splade_doc is None or model_splade_doc is None:
+        return "SPLADE v3 Doc model is not loaded. Please check the console for loading errors."
+
+    inputs = tokenizer_splade_doc(text, return_tensors="pt", padding=True, truncation=True)
+    inputs = {k: v.to(model_splade_doc.device) for k, v in inputs.items()}
+
+    with torch.no_grad():
+        output = model_splade_doc(**inputs)
+
+    if hasattr(output, 'logits'):
+        splade_vector = torch.max(
+            torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
+            dim=1
+        )[0].squeeze()
+    else:
+        return "Model output structure not as expected for SPLADE v3 Doc. 'logits' not found."
+
+    # --- Apply Lexical Mask if requested ---
+    if apply_lexical_mask:
+        vocab_size = tokenizer_splade_doc.vocab_size
+        bow_mask = create_lexical_bow_mask(
+            inputs['input_ids'], vocab_size, tokenizer_splade_doc
+        ).squeeze()
+        splade_vector = splade_vector * bow_mask
+    # --- End Lexical Mask Logic ---
+
+    indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
+    if not isinstance(indices, list):
+        indices = [indices]
+
+    values = splade_vector[indices].cpu().tolist()
+    token_weights = dict(zip(indices, values))
+
+    meaningful_tokens = {}
+    for token_id, weight in token_weights.items():
+        decoded_token = tokenizer_splade_doc.decode([token_id])
+        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
+            meaningful_tokens[decoded_token] = weight
+
+    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)
+
+    formatted_output = "SPLADE v3 Doc Representation (All Non-Zero Terms):\n"
+    if not sorted_representation:
+        formatted_output += "No significant terms found for this input.\n"
+    else:
+        for term, weight in sorted_representation:
+            formatted_output += f"- **{term}**: {weight:.4f}\n"
+
+    formatted_output += "\n--- Raw SPLADE Vector Info ---\n"
+    formatted_output += f"Total non-zero terms in vector: {len(indices)}\n"
+    formatted_output += f"Sparsity: {1 - (len(indices) / tokenizer_splade_doc.vocab_size):.2%}\n"
+
+    return formatted_output
+
+
 # --- Unified Prediction Function for Gradio ---
 def predict_representation(model_choice, text):
     if model_choice == "SPLADE (cocondenser)":
         return get_splade_representation(text)
-    elif model_choice == "SPLADE-v3-Lexical":
-        return get_splade_lexical_representation(text)
+    elif model_choice == "SPLADE-v3-Lexical":
+        # Always applies lexical mask for this option as per last request
+        return get_splade_lexical_representation(text)
+    elif model_choice == "SPLADE-v3-Doc (with expansion)": # New option
+        return get_splade_doc_representation(text, apply_lexical_mask=False)
+    elif model_choice == "SPLADE-v3-Doc (lexical-only)": # New option
+        return get_splade_doc_representation(text, apply_lexical_mask=True)
     else:
         return "Please select a model."
 
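
The aggregation in get_splade_doc_representation is standard SPLADE pooling: each vocabulary term j gets weight w_j = max_i log(1 + relu(logit_ij)) over sequence positions i, with padding positions zeroed via the attention mask before the max. Isolated below for clarity (torch.log1p(torch.relu(x)) computes the same quantity as the diff's torch.log(1 + torch.relu(x))), together with the sparsity figure the function prints, which is just 1 - nnz / vocab_size:

import torch

def splade_pool(logits, attention_mask):
    # logits: (batch, seq_len, vocab_size); attention_mask: (batch, seq_len)
    weights = torch.log1p(torch.relu(logits))         # log(1 + relu(x)), elementwise
    weights = weights * attention_mask.unsqueeze(-1)  # zero out padded positions
    return torch.max(weights, dim=1).values           # (batch, vocab_size)

def sparsity(splade_vector, vocab_size):
    # Fraction of vocabulary entries that stay exactly zero, as reported above.
    return 1.0 - torch.count_nonzero(splade_vector).item() / vocab_size
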
@@ -182,8 +258,10 @@ demo = gr.Interface(
     inputs=[
         gr.Radio(
             [
-                "SPLADE (cocondenser)",
-                "SPLADE-v3-Lexical"
+                "SPLADE (cocondenser)",
+                "SPLADE-v3-Lexical", # Lexical-only by default now
+                "SPLADE-v3-Doc (with expansion)", # Option to see full neural output
+                "SPLADE-v3-Doc (lexical-only)" # Option with lexical mask applied
             ],
             label="Choose Representation Model",
             value="SPLADE (cocondenser)" # Default selection
@@ -196,7 +274,7 @@ demo = gr.Interface(
     ],
     outputs=gr.Markdown(),
     title="🌌 Sparse Representation Generator",
-    description="Enter any text to see its SPLADE sparse vector.",
+    description="Enter any text to see its SPLADE sparse vector. Explore different SPLADE models and their expansion behaviors.",
     allow_flagging="never"
 )
 
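
End to end, predict_representation receives the radio selection as model_choice and the textbox contents as text, in the order the components appear in inputs. A minimal runnable skeleton of that wiring under the same assumptions (the gr.Textbox and the launch call are guesses; the hunks only show the radio choices, output, title, and description):

import gradio as gr

def predict_representation(model_choice, text):
    # Stand-in for the dispatch shown in the diff above.
    return f"**{model_choice}** on: {text}"

demo = gr.Interface(
    fn=predict_representation,
    inputs=[
        gr.Radio(
            ["SPLADE (cocondenser)", "SPLADE-v3-Lexical",
             "SPLADE-v3-Doc (with expansion)", "SPLADE-v3-Doc (lexical-only)"],
            label="Choose Representation Model",
            value="SPLADE (cocondenser)",
        ),
        gr.Textbox(lines=3, label="Enter text"),  # assumed input component
    ],
    outputs=gr.Markdown(),
    title="🌌 Sparse Representation Generator",
    description="Enter any text to see its SPLADE sparse vector. Explore different SPLADE models and their expansion behaviors.",
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()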