Tom Claude commited on
Commit
84f99ae
Β·
1 Parent(s): ef7af85

feat: Implement hybrid search with word boundaries, reorder UI, and add user API key management

Browse files

Major improvements to search relevancy, UX, and security:

**Search Optimization**:
- Implement PostgreSQL regex with word boundaries (\m \M) for exact matching
- Fix false positives (e.g., "F1" no longer matches "profile" or "if")
- Update Vanna system prompt with regex guidance and examples
- Create query function templates with hybrid search support

**UI Improvements**:
- Reorder modes: Inspiration (default) β†’ Refinement β†’ Chart
- Rename buttons for clarity and brevity
- Update app description to reflect all modes
- Change "Voir la source" to "Source" for consistency

**API Key Management**:
- Users now provide their own Datawrapper API keys
- Persistent storage via browser localStorage
- Session state management with validation
- Yellow warning box for permissions requirements
- Graceful error handling for missing/invalid keys
- Remove hardcoded DATAWRAPPER_ACCESS_TOKEN dependency

**New Files**:
- src/query_intent_classifier.py: Intent classification for hybrid search
- src/vanna_query_functions.py: SQL template functions with regex

**Technical Details**:
- Word boundary regex: ~* operator with \m and \M markers
- Hybrid search combines tag matching + keyword search with OR logic
- LEFT JOINs ensure untagged posts (7,245+) are included
- JavaScript localStorage integration for API key persistence

πŸ€– Generated with Claude Code (https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>

Files changed (4) hide show
  1. app.py +272 -73
  2. src/query_intent_classifier.py +238 -0
  3. src/vanna.py +87 -32
  4. src/vanna_query_functions.py +300 -0
app.py CHANGED
@@ -10,6 +10,7 @@ Now with Datawrapper integration for chart generation!
10
  import os
11
  import io
12
  import asyncio
 
13
  import pandas as pd
14
  import gradio as gr
15
  from dotenv import load_dotenv
@@ -18,6 +19,7 @@ from src.datawrapper_client import create_and_publish_chart, get_iframe_html
18
  from datetime import datetime, timedelta
19
  from collections import defaultdict
20
  from src.vanna import VannaComponent
 
21
 
22
  # Load environment variables
23
  load_dotenv()
@@ -54,6 +56,32 @@ except Exception as e:
54
  print(f"βœ— Error initializing Vanna: {e}")
55
  raise
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  def check_rate_limit(request: gr.Request) -> tuple[bool, int]:
58
  """Check if user has exceeded rate limit"""
59
  if request is None:
@@ -110,23 +138,41 @@ def recommend_stream(message: str, history: list, request: gr.Request):
110
  yield f"Error generating response: {str(e)}\n\nPlease check your environment variables (HF_TOKEN, SUPABASE_URL, SUPABASE_KEY) and try again."
111
 
112
 
113
- def generate_chart_from_csv(csv_file, user_prompt):
114
  """
115
- Generate a Datawrapper chart from uploaded CSV and user prompt.
116
 
117
  Args:
118
  csv_file: Uploaded CSV file
119
  user_prompt: User's description of the chart
 
120
 
121
  Returns:
122
  HTML string with iframe or error message
123
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  if not csv_file:
125
  return "<div style='padding: 50px; text-align: center;'>Please upload a CSV file to generate a chart.</div>"
126
 
127
  if not user_prompt or user_prompt.strip() == "":
128
  return "<div style='padding: 50px; text-align: center;'>Please describe what chart you want to create.</div>"
129
 
 
 
 
 
130
  try:
131
  # Show loading message
132
  loading_html = """
@@ -192,9 +238,15 @@ def generate_chart_from_csv(csv_file, user_prompt):
192
  <div style='padding: 50px; text-align: center; color: red;'>
193
  <h3>❌ Error</h3>
194
  <p>{str(e)}</p>
195
- <p style='font-size: 0.9em; color: #666;'>Please ensure your CSV is properly formatted and try again.</p>
196
  </div>
197
  """
 
 
 
 
 
 
198
 
199
  def csv_to_cards_html(csv_text: str) -> str:
200
  """
@@ -211,11 +263,7 @@ def csv_to_cards_html(csv_text: str) -> str:
211
  source_url = row.get("source_url", "#")
212
  author = row.get("author", "Inconnu")
213
  published_date = row.get("published_date", "")
214
- if not published_date == "nan":
215
- published_date = ""
216
- image_url = row.get("image_url", "")
217
- if not image_url == "nan":
218
- image_url = "https://fpoimg.com/800x600?text=Image+not+found"
219
 
220
  cards_html += f"""
221
  <div style="background: white; border-radius: 10px; box-shadow: 0 2px 8px rgba(0,0,0,0.1);
@@ -227,7 +275,7 @@ def csv_to_cards_html(csv_text: str) -> str:
227
  <p style="margin:0; color:#999; font-size:0.8em;">{published_date}</p>
228
  <a href="{source_url}" target="_blank"
229
  style="display:inline-block; margin-top:8px; font-size:0.9em; color:#1976d2; text-decoration:none;">
230
- πŸ”— Voir la source
231
  </a>
232
  </div>
233
  </div>
@@ -262,20 +310,60 @@ async def search_inspiration_from_database(user_prompt):
262
  """
263
 
264
  try:
265
- response = await vanna.ask(user_prompt)
266
- print("response :", repr(response))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
 
268
  clean_response = response.strip()
269
 
270
- if clean_response.startswith("⚠️") or "Aucun CSV détecté" in clean_response:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
  return f"""
272
  <div style='padding: 50px; text-align: center; color: #d9534f;'>
273
- <h3>❌ No valid data found</h3>
274
- <p>The AI couldn't generate any data for this request. Try being more specific β€” for example:
275
- <em>"Show me spotlights from 2020 about design"</em>.</p>
 
276
  </div>
277
  """
278
 
 
279
  csv_text = (
280
  clean_response
281
  .strip("```")
@@ -283,11 +371,15 @@ async def search_inspiration_from_database(user_prompt):
283
  .replace("CSV", "")
284
  )
285
 
286
- if "," not in csv_text:
 
287
  return f"""
288
  <div style='padding: 50px; text-align: center; color: #d9534f;'>
289
- <h3>❌ No valid CSV detected</h3>
290
- <p>The model didn't return any structured data. Try rephrasing your query to be more precise.</p>
 
 
 
291
  </div>
292
  """
293
 
@@ -295,11 +387,17 @@ async def search_inspiration_from_database(user_prompt):
295
  return cards_html
296
 
297
  except Exception as e:
 
 
 
298
  return f"""
299
  <div style='padding: 50px; text-align: center; color: red;'>
300
- <h3>❌ Error</h3>
301
- <p>{str(e)}</p>
302
- <p style='font-size: 0.9em; color: #666;'>Please try again.</p>
 
 
 
303
  </div>
304
  """
305
 
@@ -332,18 +430,63 @@ with gr.Blocks(
332
  gr.Markdown("""
333
  # πŸ“Š Viz LLM
334
 
335
- Get design recommendations or generate charts with AI-powered data visualization assistance.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
  """)
337
 
338
- # Mode selector buttons
339
  with gr.Row():
340
- ideation_btn = gr.Button("πŸ’‘ Ideation Mode", variant="primary", elem_classes="mode-button")
341
- chart_gen_btn = gr.Button("πŸ“Š Chart Generation Mode", variant="secondary", elem_classes="mode-button")
342
- inspiration_btn = gr.Button("✨ Inspiration Mode", variant="secondary", elem_classes="mode-button")
 
 
 
 
 
 
 
 
 
 
 
 
343
 
 
344
 
345
- # Ideation Mode: Chat interface (shown by default, wrapped in Column)
346
- with gr.Column(visible=True) as ideation_container:
347
  ideation_interface = gr.ChatInterface(
348
  fn=recommend_stream,
349
  type="messages",
@@ -360,6 +503,32 @@ with gr.Blocks(
360
 
361
  # Chart Generation Mode: Chart controls and output (hidden by default)
362
  with gr.Column(visible=False) as chart_gen_container:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
  csv_upload = gr.File(
364
  label="πŸ“ Upload CSV File",
365
  file_types=[".csv"],
@@ -379,79 +548,111 @@ with gr.Blocks(
379
  label="Generated Chart"
380
  )
381
 
382
- # Inspiration Mode:
383
- with gr.Column(visible=False) as inspiration_container:
384
- with gr.Row():
385
- inspiration_prompt_input = gr.Textbox(
386
- placeholder="Ask for an inspiration...",
387
- show_label=False,
388
- scale=4,
389
- container=False
390
- )
391
- inspiration_search_btn = gr.Button("πŸ” Search", variant="primary", scale=1)
392
 
393
- inspiration_cards_html = gr.HTML("")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
394
 
395
- # Mode switching functions
396
  def switch_to_ideation():
397
  return [
 
398
  gr.update(variant="primary"), # ideation_btn
399
  gr.update(variant="secondary"), # chart_gen_btn
400
- gr.update(variant="secondary"), # inspiration_btn
401
  gr.update(visible=True), # ideation_container
402
  gr.update(visible=False), # chart_gen_container
403
- gr.update(visible=False), # inspiration_container
404
  ]
405
 
406
  def switch_to_chart_gen():
407
  return [
 
408
  gr.update(variant="secondary"), # ideation_btn
409
  gr.update(variant="primary"), # chart_gen_btn
410
- gr.update(variant="secondary"), # inspiration_btn
411
  gr.update(visible=False), # ideation_container
412
  gr.update(visible=True), # chart_gen_container
413
- gr.update(visible=False), # inspiration_container
414
  ]
415
 
416
- def switch_to_inspiration():
417
- return [
418
- gr.update(variant="secondary"), # ideation_btn
419
- gr.update(variant="secondary"), # chart_gen_btn
420
- gr.update(variant="primary"), # inspiration_btn
421
- gr.update(visible=False), # ideation_container
422
- gr.update(visible=False), # chart_gen_container
423
- gr.update(visible=True), # inspiration_container
424
- ]
425
 
426
- # Wire up mode switching
427
  ideation_btn.click(
428
  fn=switch_to_ideation,
429
  inputs=[],
430
- outputs=[ideation_btn, chart_gen_btn, inspiration_btn, ideation_container, chart_gen_container, inspiration_container]
431
  )
432
 
433
  chart_gen_btn.click(
434
  fn=switch_to_chart_gen,
435
  inputs=[],
436
- outputs=[ideation_btn, chart_gen_btn, inspiration_btn, ideation_container, chart_gen_container, inspiration_container]
437
  )
438
 
439
- inspiration_btn.click(
440
- fn=switch_to_inspiration,
441
- inputs=[],
442
- outputs=[ideation_btn, chart_gen_btn, inspiration_btn, ideation_container, chart_gen_container, inspiration_container]
 
 
443
  )
444
 
445
- # Generate chart when button is clicked
446
  generate_chart_btn.click(
447
  fn=generate_chart_from_csv,
448
- inputs=[csv_upload, chart_prompt_input],
449
  outputs=[chart_output]
450
  )
451
 
452
- # Search inspiration when button is clicked
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
453
  inspiration_search_btn.click(
454
- fn=search_inspiration_from_database,
455
  inputs=[inspiration_prompt_input],
456
  outputs=[inspiration_cards_html]
457
  )
@@ -460,11 +661,11 @@ with gr.Blocks(
460
  gr.Markdown("""
461
  ### About Viz LLM
462
 
463
- **Ideation Mode:** Get design recommendations based on research papers, design principles, and examples from the field of information graphics and data visualization.
 
 
464
 
465
- **Chart Generation Mode:** Upload your CSV data and describe your visualization goal. The AI will analyze your data, select the optimal chart type, and generate a publication-ready chart using Datawrapper.
466
-
467
- **Inspiration Mode:** Coming soon.
468
 
469
  **Credits:** Special thanks to the researchers whose work informed this model: Robert Kosara, Edward Segel, Jeffrey Heer, Matthew Conlen, John Maeda, Kennedy Elliott, Scott McCloud, and many others.
470
 
@@ -473,21 +674,19 @@ with gr.Blocks(
473
  **Usage Limits:** This service is limited to 20 queries per day per user to manage costs. Responses are optimized for English.
474
 
475
  <div style="text-align: center; margin-top: 20px; opacity: 0.6; font-size: 0.9em;">
476
- Embeddings: Jina-CLIP-v2 | Charts: Datawrapper API
477
  </div>
478
  """)
479
 
480
  # Launch configuration
481
  if __name__ == "__main__":
482
- # Check for required environment variables
483
- required_vars = ["SUPABASE_URL", "SUPABASE_KEY", "HF_TOKEN", "DATAWRAPPER_ACCESS_TOKEN"]
484
  missing_vars = [var for var in required_vars if not os.getenv(var)]
485
 
486
  if missing_vars:
487
  print(f"⚠️ Warning: Missing environment variables: {', '.join(missing_vars)}")
488
  print("Please set these in your .env file or as environment variables")
489
- if "DATAWRAPPER_ACCESS_TOKEN" in missing_vars:
490
- print("Note: DATAWRAPPER_ACCESS_TOKEN is required for chart generation mode")
491
 
492
  # Launch the app
493
  demo.launch(
 
10
  import os
11
  import io
12
  import asyncio
13
+ import time
14
  import pandas as pd
15
  import gradio as gr
16
  from dotenv import load_dotenv
 
19
  from datetime import datetime, timedelta
20
  from collections import defaultdict
21
  from src.vanna import VannaComponent
22
+ from src.query_intent_classifier import classify_query, IntentClassifier
23
 
24
  # Load environment variables
25
  load_dotenv()
 
56
  print(f"βœ— Error initializing Vanna: {e}")
57
  raise
58
 
59
+ # CSV cleanup function
60
+ def cleanup_old_csv_files():
61
+ """Delete CSV files older than 24 hours to prevent accumulation"""
62
+ folder = "513935c4d2db2d2d"
63
+ if not os.path.exists(folder):
64
+ return
65
+
66
+ cleaned = 0
67
+ for file in os.listdir(folder):
68
+ if file.endswith(".csv"):
69
+ file_path = os.path.join(folder, file)
70
+ try:
71
+ # Check if file is older than 24 hours
72
+ if os.path.getmtime(file_path) < time.time() - 86400:
73
+ os.remove(file_path)
74
+ cleaned += 1
75
+ except Exception as e:
76
+ print(f"Warning: Could not delete {file_path}: {e}")
77
+
78
+ if cleaned > 0:
79
+ print(f"βœ“ Cleaned up {cleaned} old CSV files")
80
+
81
+ # Run cleanup on startup
82
+ print("Cleaning up old CSV files...")
83
+ cleanup_old_csv_files()
84
+
85
  def check_rate_limit(request: gr.Request) -> tuple[bool, int]:
86
  """Check if user has exceeded rate limit"""
87
  if request is None:
 
138
  yield f"Error generating response: {str(e)}\n\nPlease check your environment variables (HF_TOKEN, SUPABASE_URL, SUPABASE_KEY) and try again."
139
 
140
 
141
+ def generate_chart_from_csv(csv_file, user_prompt, api_key):
142
  """
143
+ Generate a Datawrapper chart from uploaded CSV and user prompt using user's API key.
144
 
145
  Args:
146
  csv_file: Uploaded CSV file
147
  user_prompt: User's description of the chart
148
+ api_key: User's Datawrapper API key
149
 
150
  Returns:
151
  HTML string with iframe or error message
152
  """
153
+ # Validate API key first
154
+ if not api_key or api_key.strip() == "":
155
+ return """
156
+ <div style='padding: 50px; text-align: center; color: #d9534f;'>
157
+ <h3>❌ No API Key Provided</h3>
158
+ <p>Please enter your Datawrapper API key above to generate charts.</p>
159
+ <p style='margin-top: 15px;'>
160
+ <a href='https://app.datawrapper.de/account/api-tokens' target='_blank'
161
+ style='color: #1976d2; text-decoration: underline;'>Get your API key β†’</a>
162
+ </p>
163
+ </div>
164
+ """
165
+
166
  if not csv_file:
167
  return "<div style='padding: 50px; text-align: center;'>Please upload a CSV file to generate a chart.</div>"
168
 
169
  if not user_prompt or user_prompt.strip() == "":
170
  return "<div style='padding: 50px; text-align: center;'>Please describe what chart you want to create.</div>"
171
 
172
+ # Temporarily set the API key in environment for this request
173
+ original_key = os.environ.get("DATAWRAPPER_ACCESS_TOKEN")
174
+ os.environ["DATAWRAPPER_ACCESS_TOKEN"] = api_key
175
+
176
  try:
177
  # Show loading message
178
  loading_html = """
 
238
  <div style='padding: 50px; text-align: center; color: red;'>
239
  <h3>❌ Error</h3>
240
  <p>{str(e)}</p>
241
+ <p style='font-size: 0.9em; color: #666;'>Please ensure your CSV is properly formatted and your API key is correct.</p>
242
  </div>
243
  """
244
+ finally:
245
+ # Restore original API key or remove it
246
+ if original_key:
247
+ os.environ["DATAWRAPPER_ACCESS_TOKEN"] = original_key
248
+ elif "DATAWRAPPER_ACCESS_TOKEN" in os.environ:
249
+ del os.environ["DATAWRAPPER_ACCESS_TOKEN"]
250
 
251
  def csv_to_cards_html(csv_text: str) -> str:
252
  """
 
263
  source_url = row.get("source_url", "#")
264
  author = row.get("author", "Inconnu")
265
  published_date = row.get("published_date", "")
266
+ image_url = row.get("image_url", "https://fpoimg.com/800x600?text=Image+not+found")
 
 
 
 
267
 
268
  cards_html += f"""
269
  <div style="background: white; border-radius: 10px; box-shadow: 0 2px 8px rgba(0,0,0,0.1);
 
275
  <p style="margin:0; color:#999; font-size:0.8em;">{published_date}</p>
276
  <a href="{source_url}" target="_blank"
277
  style="display:inline-block; margin-top:8px; font-size:0.9em; color:#1976d2; text-decoration:none;">
278
+ πŸ”— Source
279
  </a>
280
  </div>
281
  </div>
 
310
  """
311
 
312
  try:
313
+ # Classify user intent
314
+ print(f"\n{'='*60}")
315
+ print(f"[SEARCH] User prompt: {user_prompt}")
316
+
317
+ classifier = IntentClassifier()
318
+ classification = classifier.classify(user_prompt)
319
+
320
+ print(f"[INTENT] Type: {classification['intent'].value}")
321
+ print(f"[INTENT] Keywords: {classification['keywords']}")
322
+ print(f"[INTENT] Inferred tags: {classification['tags']}")
323
+ print(f"[INTENT] Short query: {classification['is_short_query']}")
324
+
325
+ # Enhance prompt with intent guidance
326
+ enhanced_prompt = classifier.format_for_vanna(classification)
327
+ full_prompt = f"{user_prompt}\n\n{enhanced_prompt}"
328
+
329
+ print(f"[VANNA] Sending enhanced prompt to Vanna...")
330
+ response = await vanna.ask(full_prompt)
331
+ print(f"[VANNA] Response received: {repr(response)[:200]}...")
332
+ print(f"{'='*60}\n")
333
 
334
  clean_response = response.strip()
335
 
336
+ # Check for empty query results (0 rows returned)
337
+ if "No rows returned" in clean_response or "0 rows" in clean_response.lower():
338
+ return f"""
339
+ <div style='padding: 50px; text-align: center; color: #f0ad4e;'>
340
+ <h3>πŸ” No Results Found</h3>
341
+ <p>Your query was executed successfully, but no posts matched your criteria.</p>
342
+ <p style='margin-top: 15px; font-weight: 600;'>Suggestions:</p>
343
+ <ul style='list-style: none; padding: 0; text-align: left; display: inline-block;'>
344
+ <li>β€’ Try broader keywords (e.g., "visualization" instead of "F1 dataviz")</li>
345
+ <li>β€’ Search by author names (e.g., "New York Times")</li>
346
+ <li>β€’ Use simple terms (e.g., "interactive", "maps")</li>
347
+ </ul>
348
+ <p style='margin-top: 15px; font-style: italic; color: #666; font-size: 0.9em;'>
349
+ <strong>Note:</strong> Most posts are currently being enriched with tags.<br/>
350
+ Keyword search works for all {classification.get('total_posts', '7,000+')} posts in the database.
351
+ </p>
352
+ </div>
353
+ """
354
+
355
+ # Check for errors or warnings
356
+ if clean_response.startswith("⚠️") or clean_response.startswith("❌") or "Aucun CSV détecté" in clean_response:
357
  return f"""
358
  <div style='padding: 50px; text-align: center; color: #d9534f;'>
359
+ <h3>❌ Query Error</h3>
360
+ <p>The AI encountered an issue processing your request.</p>
361
+ <p style='margin-top: 10px; font-size: 0.9em; color: #666;'>{clean_response[:200]}</p>
362
+ <p style='margin-top: 15px;'>Try rephrasing your query or being more specific.</p>
363
  </div>
364
  """
365
 
366
+ # Process CSV response
367
  csv_text = (
368
  clean_response
369
  .strip("```")
 
371
  .replace("CSV", "")
372
  )
373
 
374
+ # Check if response contains CSV data
375
+ if "," not in csv_text or "id,title" not in csv_text.lower():
376
  return f"""
377
  <div style='padding: 50px; text-align: center; color: #d9534f;'>
378
+ <h3>❌ Invalid Response Format</h3>
379
+ <p>The database query didn't return structured data.</p>
380
+ <p style='margin-top: 10px; font-size: 0.9em; color: #666;'>
381
+ This might be a temporary issue. Please try again.
382
+ </p>
383
  </div>
384
  """
385
 
 
387
  return cards_html
388
 
389
  except Exception as e:
390
+ print(f"❌ Exception in search_inspiration_from_database: {str(e)}")
391
+ import traceback
392
+ traceback.print_exc()
393
  return f"""
394
  <div style='padding: 50px; text-align: center; color: red;'>
395
+ <h3>❌ System Error</h3>
396
+ <p style='margin-bottom: 10px;'>An unexpected error occurred:</p>
397
+ <p style='font-family: monospace; font-size: 0.85em; color: #666;'>{str(e)}</p>
398
+ <p style='margin-top: 15px; font-size: 0.9em; color: #666;'>
399
+ Please check the console logs for more details.
400
+ </p>
401
  </div>
402
  """
403
 
 
430
  gr.Markdown("""
431
  # πŸ“Š Viz LLM
432
 
433
+ Discover inspiring visualizations, refine your design ideas, or generate publication-ready charts with AI assistance.
434
+ """)
435
+
436
+ # JavaScript for localStorage persistence
437
+ gr.HTML("""
438
+ <script>
439
+ // Save API key to localStorage when it changes
440
+ function saveApiKeyToStorage(key) {
441
+ if (key && key.trim() !== '') {
442
+ localStorage.setItem('datawrapper_api_key', key);
443
+ }
444
+ }
445
+
446
+ // Load API key from localStorage on page load
447
+ function loadApiKeyFromStorage() {
448
+ return localStorage.getItem('datawrapper_api_key') || '';
449
+ }
450
+
451
+ // Auto-load API key when the page loads
452
+ window.addEventListener('DOMContentLoaded', function() {
453
+ setTimeout(function() {
454
+ const savedKey = loadApiKeyFromStorage();
455
+ if (savedKey) {
456
+ const apiKeyInput = document.querySelector('input[type="password"]');
457
+ if (apiKeyInput) {
458
+ apiKeyInput.value = savedKey;
459
+ // Trigger change event to update Gradio state
460
+ apiKeyInput.dispatchEvent(new Event('input', { bubbles: true }));
461
+ }
462
+ }
463
+ }, 1000);
464
+ });
465
+ </script>
466
  """)
467
 
468
+ # Mode selector buttons (reordered: Inspiration, Refinement, Chart)
469
  with gr.Row():
470
+ inspiration_btn = gr.Button("✨ Inspiration", variant="primary", elem_classes="mode-button")
471
+ ideation_btn = gr.Button("πŸ’‘ Refinement", variant="secondary", elem_classes="mode-button")
472
+ chart_gen_btn = gr.Button("πŸ“Š Chart", variant="secondary", elem_classes="mode-button")
473
+
474
+
475
+ # Inspiration Mode: Search interface (shown by default)
476
+ with gr.Column(visible=True) as inspiration_container:
477
+ with gr.Row():
478
+ inspiration_prompt_input = gr.Textbox(
479
+ placeholder="Search for inspiration (e.g., 'F1', 'interactive maps')...",
480
+ show_label=False,
481
+ scale=4,
482
+ container=False
483
+ )
484
+ inspiration_search_btn = gr.Button("πŸ” Search", variant="primary", scale=1)
485
 
486
+ inspiration_cards_html = gr.HTML("")
487
 
488
+ # Refinement Mode: Chat interface (hidden by default, wrapped in Column)
489
+ with gr.Column(visible=False) as ideation_container:
490
  ideation_interface = gr.ChatInterface(
491
  fn=recommend_stream,
492
  type="messages",
 
503
 
504
  # Chart Generation Mode: Chart controls and output (hidden by default)
505
  with gr.Column(visible=False) as chart_gen_container:
506
+ gr.Markdown("### Chart Generator")
507
+
508
+ # API Key Input (collapsible)
509
+ with gr.Accordion("πŸ”‘ Datawrapper API Key", open=False):
510
+ gr.Markdown("""
511
+ Enter your Datawrapper API key to generate charts. Your key is stored in your browser and persists across sessions.
512
+
513
+ **Get your key**: [Datawrapper Account Settings](https://app.datawrapper.de/account/api-tokens)
514
+ """)
515
+
516
+ # Warning about permissions
517
+ gr.HTML("""
518
+ <div style="background: #fff3cd; border: 1px solid #ffc107; border-radius: 5px; padding: 12px; margin: 10px 0;">
519
+ <strong>⚠️ Important:</strong> When creating your API key, toggle <strong>ALL permissions</strong> (Read & Write for Charts, Tables, Folders, etc.) otherwise chart generation will fail.
520
+ </div>
521
+ """)
522
+
523
+ api_key_input = gr.Textbox(
524
+ label="API Key",
525
+ placeholder="Paste your Datawrapper API key here...",
526
+ type="password",
527
+ value=""
528
+ )
529
+
530
+ api_key_status = gr.Markdown("⚠️ Status: No API key provided")
531
+
532
  csv_upload = gr.File(
533
  label="πŸ“ Upload CSV File",
534
  file_types=[".csv"],
 
548
  label="Generated Chart"
549
  )
550
 
551
+ # API key state management
552
+ api_key_state = gr.State(value="")
 
 
 
 
 
 
 
 
553
 
554
+ def validate_api_key(api_key: str) -> tuple[str, str]:
555
+ """Validate and store API key"""
556
+ if not api_key or api_key.strip() == "":
557
+ return "", "⚠️ Status: No API key provided"
558
+
559
+ # Basic validation (check format)
560
+ if len(api_key) < 20:
561
+ return "", "❌ Status: Invalid API key format (too short)"
562
+
563
+ # Key looks valid - it will be saved to localStorage via JavaScript
564
+ masked_key = f"...{api_key[-6:]}" if len(api_key) > 6 else "***"
565
+ return api_key, f"βœ… Status: API key saved to browser storage (ends with {masked_key})"
566
+
567
+ # Mode switching functions (updated for new order: Inspiration, Refinement, Chart)
568
+ def switch_to_inspiration():
569
+ return [
570
+ gr.update(variant="primary"), # inspiration_btn
571
+ gr.update(variant="secondary"), # ideation_btn
572
+ gr.update(variant="secondary"), # chart_gen_btn
573
+ gr.update(visible=True), # inspiration_container
574
+ gr.update(visible=False), # ideation_container
575
+ gr.update(visible=False), # chart_gen_container
576
+ ]
577
 
 
578
  def switch_to_ideation():
579
  return [
580
+ gr.update(variant="secondary"), # inspiration_btn
581
  gr.update(variant="primary"), # ideation_btn
582
  gr.update(variant="secondary"), # chart_gen_btn
583
+ gr.update(visible=False), # inspiration_container
584
  gr.update(visible=True), # ideation_container
585
  gr.update(visible=False), # chart_gen_container
 
586
  ]
587
 
588
  def switch_to_chart_gen():
589
  return [
590
+ gr.update(variant="secondary"), # inspiration_btn
591
  gr.update(variant="secondary"), # ideation_btn
592
  gr.update(variant="primary"), # chart_gen_btn
593
+ gr.update(visible=False), # inspiration_container
594
  gr.update(visible=False), # ideation_container
595
  gr.update(visible=True), # chart_gen_container
 
596
  ]
597
 
598
+ # Wire up mode switching (updated order: inspiration, ideation, chart)
599
+ inspiration_btn.click(
600
+ fn=switch_to_inspiration,
601
+ inputs=[],
602
+ outputs=[inspiration_btn, ideation_btn, chart_gen_btn, inspiration_container, ideation_container, chart_gen_container]
603
+ )
 
 
 
604
 
 
605
  ideation_btn.click(
606
  fn=switch_to_ideation,
607
  inputs=[],
608
+ outputs=[inspiration_btn, ideation_btn, chart_gen_btn, inspiration_container, ideation_container, chart_gen_container]
609
  )
610
 
611
  chart_gen_btn.click(
612
  fn=switch_to_chart_gen,
613
  inputs=[],
614
+ outputs=[inspiration_btn, ideation_btn, chart_gen_btn, inspiration_container, ideation_container, chart_gen_container]
615
  )
616
 
617
+ # Connect API key validation and localStorage save
618
+ api_key_input.change(
619
+ fn=validate_api_key,
620
+ inputs=[api_key_input],
621
+ outputs=[api_key_state, api_key_status],
622
+ js="(key) => { saveApiKeyToStorage(key); return key; }"
623
  )
624
 
625
+ # Generate chart when button is clicked (now with API key)
626
  generate_chart_btn.click(
627
  fn=generate_chart_from_csv,
628
+ inputs=[csv_upload, chart_prompt_input, api_key_state],
629
  outputs=[chart_output]
630
  )
631
 
632
+ # Search inspiration with loading state
633
+ def search_with_loading(prompt):
634
+ """Wrapper to show loading state"""
635
+ if not prompt or not prompt.strip():
636
+ return """
637
+ <div style='padding: 50px; text-align: center;'>
638
+ Please enter a search query.
639
+ </div>
640
+ """
641
+ # Show loading immediately (Gradio will display this first)
642
+ yield """
643
+ <div style='padding: 50px; text-align: center;'>
644
+ <div style='font-size: 2em; margin-bottom: 20px;'>πŸ”</div>
645
+ <h3>Searching database...</h3>
646
+ <p style='color: #666;'>Analyzing your query and generating SQL...</p>
647
+ </div>
648
+ """
649
+ # Run the actual search
650
+ import asyncio
651
+ result = asyncio.run(search_inspiration_from_database(prompt))
652
+ yield result
653
+
654
  inspiration_search_btn.click(
655
+ fn=search_with_loading,
656
  inputs=[inspiration_prompt_input],
657
  outputs=[inspiration_cards_html]
658
  )
 
661
  gr.Markdown("""
662
  ### About Viz LLM
663
 
664
+ **Inspiration**: Discover curated examples of data visualizations and information graphics from publications worldwide. Search by keyword, topic, or author.
665
+
666
+ **Refinement**: Get design recommendations based on research papers, design principles, and examples from the field of information graphics and data visualization.
667
 
668
+ **Chart**: Upload your CSV data and describe your visualization goal. The AI will analyze your data, select the optimal chart type, and generate a publication-ready chart using Datawrapper.
 
 
669
 
670
  **Credits:** Special thanks to the researchers whose work informed this model: Robert Kosara, Edward Segel, Jeffrey Heer, Matthew Conlen, John Maeda, Kennedy Elliott, Scott McCloud, and many others.
671
 
 
674
  **Usage Limits:** This service is limited to 20 queries per day per user to manage costs. Responses are optimized for English.
675
 
676
  <div style="text-align: center; margin-top: 20px; opacity: 0.6; font-size: 0.9em;">
677
+ Embeddings: Jina-CLIP-v2 | Charts: Datawrapper API | Database: Nuanced
678
  </div>
679
  """)
680
 
681
  # Launch configuration
682
  if __name__ == "__main__":
683
+ # Check for required environment variables (Datawrapper key now user-provided)
684
+ required_vars = ["SUPABASE_URL", "SUPABASE_KEY", "HF_TOKEN"]
685
  missing_vars = [var for var in required_vars if not os.getenv(var)]
686
 
687
  if missing_vars:
688
  print(f"⚠️ Warning: Missing environment variables: {', '.join(missing_vars)}")
689
  print("Please set these in your .env file or as environment variables")
 
 
690
 
691
  # Launch the app
692
  demo.launch(
src/query_intent_classifier.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Query Intent Classifier for Hybrid Search
3
+
4
+ Analyzes user queries to determine the best search strategy:
5
+ - keyword: Full-text search on title/author/provider (works for all posts)
6
+ - tag: Tag-based search (works only for tagged posts)
7
+ - hybrid: Try both approaches
8
+ """
9
+
10
+ import re
11
+ from typing import Dict, List
12
+ from enum import Enum
13
+
14
+
15
+ class QueryIntent(Enum):
16
+ KEYWORD = "keyword"
17
+ TAG = "tag"
18
+ HYBRID = "hybrid"
19
+
20
+
21
+ class IntentClassifier:
22
+ """
23
+ Classifies user queries and extracts relevant search parameters.
24
+ """
25
+
26
+ # Keywords that suggest tag search
27
+ TAG_INDICATORS = ["tagged", "category", "topic", "theme", "type", "about"]
28
+
29
+ # Common keywords to expand for better matching
30
+ KEYWORD_EXPANSIONS = {
31
+ "f1": ["f1", "formula 1", "formula one", "racing"],
32
+ "dataviz": ["dataviz", "data visualization", "visualization", "chart", "graph"],
33
+ "interactive": ["interactive", "interaction", "explore"],
34
+ "map": ["map", "maps", "mapping", "geographic", "geo"],
35
+ "nyt": ["new york times", "nyt", "ny times"],
36
+ }
37
+
38
+ def __init__(self):
39
+ pass
40
+
41
+ def classify(self, user_prompt: str) -> Dict:
42
+ """
43
+ Classify user intent and extract search parameters.
44
+
45
+ Args:
46
+ user_prompt: The user's search query
47
+
48
+ Returns:
49
+ Dict with:
50
+ - intent: QueryIntent enum
51
+ - keywords: List of keywords to search
52
+ - tags: List of potential tags to search
53
+ - original_query: Original user prompt
54
+ """
55
+ prompt_lower = user_prompt.lower().strip()
56
+
57
+ # Detect intent
58
+ intent = self._detect_intent(prompt_lower)
59
+
60
+ # Extract keywords
61
+ keywords = self._extract_keywords(prompt_lower)
62
+
63
+ # Infer potential tags
64
+ tags = self._infer_tags(prompt_lower, keywords)
65
+
66
+ return {
67
+ "intent": intent,
68
+ "keywords": keywords,
69
+ "tags": tags,
70
+ "original_query": user_prompt,
71
+ "is_short_query": len(prompt_lower.split()) <= 3
72
+ }
73
+
74
+ def _detect_intent(self, prompt: str) -> QueryIntent:
75
+ """
76
+ Determine if user wants tag search, keyword search, or hybrid.
77
+ """
78
+ # Check for tag indicators
79
+ has_tag_indicator = any(indicator in prompt for indicator in self.TAG_INDICATORS)
80
+
81
+ # Short queries (1-3 words) should try hybrid approach
82
+ word_count = len(prompt.split())
83
+
84
+ if has_tag_indicator:
85
+ return QueryIntent.TAG
86
+ elif word_count <= 3:
87
+ # Short queries: try both tag and keyword search
88
+ return QueryIntent.HYBRID
89
+ else:
90
+ # Longer natural language queries: keyword search first
91
+ return QueryIntent.KEYWORD
92
+
93
+ def _extract_keywords(self, prompt: str) -> List[str]:
94
+ """
95
+ Extract meaningful keywords from the prompt.
96
+ """
97
+ # Remove common stop words
98
+ stop_words = {
99
+ "show", "me", "find", "get", "search", "for", "the", "a", "an",
100
+ "with", "about", "of", "in", "on", "at", "to", "from", "by",
101
+ "what", "where", "when", "who", "how", "is", "are", "was", "were"
102
+ }
103
+
104
+ # Split and clean
105
+ words = re.findall(r'\b\w+\b', prompt.lower())
106
+ # Allow 2-character words like "F1", "AI", "3D"
107
+ keywords = [w for w in words if w not in stop_words and len(w) >= 2]
108
+
109
+ # Expand known keywords
110
+ expanded_keywords = []
111
+ for keyword in keywords:
112
+ if keyword in self.KEYWORD_EXPANSIONS:
113
+ expanded_keywords.extend(self.KEYWORD_EXPANSIONS[keyword])
114
+ else:
115
+ expanded_keywords.append(keyword)
116
+
117
+ # Remove duplicates while preserving order
118
+ return list(dict.fromkeys(expanded_keywords))
119
+
120
+ def _infer_tags(self, prompt: str, keywords: List[str]) -> List[str]:
121
+ """
122
+ Infer potential tag names from keywords.
123
+
124
+ Since we have limited tags in the database, we map common terms
125
+ to likely tag names.
126
+ """
127
+ # Common tag mappings based on the database
128
+ tag_mappings = {
129
+ "f1": ["f1", "racing", "motorsport", "sports"],
130
+ "formula": ["f1", "racing", "motorsport"],
131
+ "racing": ["racing", "motorsport", "f1"],
132
+ "dataviz": ["dataviz", "visualization"],
133
+ "visualization": ["dataviz", "visualization"],
134
+ "interactive": ["interactive"],
135
+ "map": ["maps", "geographic"],
136
+ "maps": ["maps", "geographic"],
137
+ "math": ["mathematics", "statistics"],
138
+ "statistics": ["statistics", "mathematics"],
139
+ "africa": ["africa", "kenya", "tanzania"],
140
+ "sustainability": ["sustainability", "regreening"],
141
+ "documentary": ["documentary", "cinematic"],
142
+ "education": ["students", "researchers"],
143
+ }
144
+
145
+ inferred_tags = []
146
+ for keyword in keywords:
147
+ if keyword in tag_mappings:
148
+ inferred_tags.extend(tag_mappings[keyword])
149
+
150
+ # If no specific mapping, use the keyword as-is
151
+ if not inferred_tags:
152
+ inferred_tags = keywords[:3] # Limit to top 3 keywords
153
+
154
+ # Remove duplicates
155
+ return list(dict.fromkeys(inferred_tags))
156
+
157
+ def format_for_vanna(self, classification: Dict) -> str:
158
+ """
159
+ Format the classification result for Vanna's prompt.
160
+
161
+ Returns a string that guides Vanna's SQL generation.
162
+ """
163
+ intent = classification["intent"]
164
+ keywords = classification["keywords"]
165
+ tags = classification["tags"]
166
+
167
+ if intent == QueryIntent.KEYWORD:
168
+ return f"""
169
+ Search using KEYWORD approach:
170
+ - Search terms: {', '.join(keywords)}
171
+ - Search in: posts.title, posts.author, providers.name
172
+ - Use ILIKE with wildcards for flexible matching
173
+ - Do not filter by tags (most posts are not tagged yet)
174
+ """
175
+
176
+ elif intent == QueryIntent.TAG:
177
+ return f"""
178
+ Search using TAG approach:
179
+ - Tag names: {', '.join(tags)}
180
+ - Use LOWER() for case-insensitive matching
181
+ - Join with post_tags and tags tables
182
+ - Note: Only a few posts are tagged, results may be limited
183
+ """
184
+
185
+ else: # HYBRID
186
+ return f"""
187
+ Search using HYBRID approach:
188
+ - Try tags first: {', '.join(tags)}
189
+ - Fall back to keywords: {', '.join(keywords)}
190
+ - Use OR logic: tag matches OR keyword matches in title/author
191
+ - This maximizes results since most posts are not tagged yet
192
+
193
+ Recommended SQL pattern:
194
+ SELECT DISTINCT p.id, p.title, p.source_url, p.author, p.published_date, p.image_url, p.type
195
+ FROM posts p
196
+ LEFT JOIN post_tags pt ON p.id = pt.post_id
197
+ LEFT JOIN tags t ON pt.tag_id = t.id
198
+ LEFT JOIN providers pr ON p.provider_id = pr.id
199
+ WHERE
200
+ LOWER(t.name) = ANY(ARRAY[{', '.join(f"'{tag}'" for tag in tags)}])
201
+ OR LOWER(p.title) LIKE ANY(ARRAY[{', '.join(f"'%{kw}%'" for kw in keywords)}])
202
+ OR LOWER(p.author) LIKE ANY(ARRAY[{', '.join(f"'%{kw}%'" for kw in keywords)}])
203
+ OR LOWER(pr.name) LIKE ANY(ARRAY[{', '.join(f"'%{kw}%'" for kw in keywords)}])
204
+ ORDER BY p.published_date DESC NULLS LAST
205
+ LIMIT 9
206
+ """
207
+
208
+
209
+ # Convenience function
210
+ def classify_query(user_prompt: str) -> Dict:
211
+ """
212
+ Classify a user query and return search parameters.
213
+ """
214
+ classifier = IntentClassifier()
215
+ return classifier.classify(user_prompt)
216
+
217
+
218
+ # Example usage
219
+ if __name__ == "__main__":
220
+ # Test cases
221
+ test_queries = [
222
+ "F1",
223
+ "Show me F1 content",
224
+ "interactive visualizations",
225
+ "New York Times articles",
226
+ "content tagged with dataviz",
227
+ "recent sustainability projects in Africa",
228
+ ]
229
+
230
+ classifier = IntentClassifier()
231
+
232
+ for query in test_queries:
233
+ result = classifier.classify(query)
234
+ print(f"\nQuery: '{query}'")
235
+ print(f"Intent: {result['intent'].value}")
236
+ print(f"Keywords: {result['keywords']}")
237
+ print(f"Tags: {result['tags']}")
238
+ print(f"Short query: {result['is_short_query']}")
src/vanna.py CHANGED
@@ -55,9 +55,6 @@ class CustomSQLSystemPromptBuilder(SystemPromptBuilder):
55
  "- Never use SELECT *\n"
56
  "- Prefer window functions over subqueries when possible\n"
57
  "- Always include a LIMIT for exploratory queries\n"
58
- "- Exclude posts where provider = 'SND'\n"
59
- "- Exclude posts where type = 'resource'\n"
60
- "- Exclude posts where type = 'insight'\n"
61
  "- Format dates and numbers for readability\n"
62
  )
63
 
@@ -106,15 +103,32 @@ class CustomSQLSystemPromptBuilder(SystemPromptBuilder):
106
  # ======================
107
  prompt += (
108
  "\n## Business Logic\n"
109
- "- Providers named 'SND' must always be excluded.\n"
110
  "- A query mentioning an organization (e.g., 'New York Times') should search both `posts.author` and `providers.name`.\n"
111
- "- By default, only posts with `type = 'spotlight'` are returned.\n"
112
- "- Posts of type `resource` or `insight` are excluded unless explicitly requested.\n"
113
  "- Tags link posts to specific themes or disciplines.\n"
114
  "- A single post may have multiple tags, awards, or categories.\n"
115
  "- If the user mentions a year (e.g., 'in 2021'), filter with `EXTRACT(YEAR FROM published_date) = 2021`.\n"
116
  "- If the user says 'recently', filter posts from the last 90 days.\n"
117
  "- Always limit exploratory results to 9 rows.\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  )
119
 
120
  # ======================
@@ -145,21 +159,30 @@ class CustomSQLSystemPromptBuilder(SystemPromptBuilder):
145
  # ======================
146
  prompt += (
147
  "\n## Example Interactions\n"
148
- "User: 'Show me posts related to 3D'\n"
149
- "Assistant: [call run_sql with \"SELECT p.id, p.title, p.source_url, p.author, p.published_date, p.image_url, p.type "
150
  "FROM posts p "
151
- "JOIN post_tags pt ON p.id = pt.post_id "
152
- "JOIN tags t ON pt.tag_id = t.id "
153
- "JOIN providers pr ON p.provider_id = pr.id "
154
- "WHERE t.name ILIKE '%3D%' AND pr.name != 'SND' AND p.type = 'spotlight' "
155
- "LIMIT 9;\"]\n"
 
 
156
  "\nUser: 'Show me posts from The New York Times'\n"
157
- "Assistant: [call run_sql with \"SELECT p.id, p.title, p.source_url, p.author, p.published_date, p.image_url, p.type "
158
  "FROM posts p "
159
- "LEFT JOIN providers pr ON pr.id = p.provider_id "
160
- "WHERE LOWER(p.author) LIKE '%new york times%' OR LOWER(pr.name) LIKE '%new york times%' "
161
- "AND pr.name != 'SND' AND p.type = 'spotlight' "
162
- "LIMIT 9;\"]\n"
 
 
 
 
 
 
 
163
  )
164
 
165
  # ======================
@@ -167,8 +190,6 @@ class CustomSQLSystemPromptBuilder(SystemPromptBuilder):
167
  # ======================
168
  prompt += (
169
  "\nIMPORTANT:\n"
170
- "- Always exclude posts with provider = 'SND'.\n"
171
- "- Always exclude posts with type = 'resource' or 'insight'.\n"
172
  "- Always return **only the raw CSV result** β€” no explanations, no JSON, no commentary.\n"
173
  "- Stop tool execution once the query result is obtained.\n"
174
  )
@@ -197,8 +218,8 @@ class VannaComponent:
197
  db_tool = RunSqlTool(sql_runner=self.sql_runner)
198
 
199
  agent_memory = DemoAgentMemory(max_items=1000)
200
- save_memory_tool = SaveQuestionToolArgsTool(agent_memory)
201
- search_memory_tool = SearchSavedCorrectToolUsesTool(agent_memory)
202
 
203
  self.user_resolver = SimpleUserResolver()
204
 
@@ -211,32 +232,46 @@ class VannaComponent:
211
  llm_service=llm,
212
  tool_registry=tools,
213
  user_resolver=self.user_resolver,
 
214
  system_prompt_builder=CustomSQLSystemPromptBuilder("CoJournalist", self.sql_runner),
215
- config=AgentConfig(stream_responses=False, max_tool_iterations=1)
216
  )
217
 
218
  async def ask(self, prompt_for_llm: str):
219
  ctx = RequestContext()
220
- print(f"πŸ™‹ Prompt sent to LLM: {prompt_for_llm}")
 
 
221
 
222
  final_text = ""
223
  seen_texts = set()
 
 
224
 
225
  async for component in self.agent.send_message(request_context=ctx, message=prompt_for_llm):
226
  simple = getattr(component, "simple_component", None)
227
  text = getattr(simple, "text", "") if simple else ""
228
  if text and text not in seen_texts:
229
- print(f"πŸ’¬ LLM says (part): {text[:200]}...")
230
  final_text += text + "\n"
231
  seen_texts.add(text)
232
 
233
  sql_query = getattr(component, "sql", None)
234
  if sql_query:
235
- print(f"🧾 SQL Query Generated: {sql_query}")
 
 
 
 
236
 
237
  metadata = getattr(component, "metadata", None)
238
  if metadata:
239
- print(f"πŸ“‹ Metadata: {metadata}")
 
 
 
 
 
240
 
241
  component_type = getattr(component, "type", None)
242
  if component_type:
@@ -245,16 +280,36 @@ class VannaComponent:
245
  match = re.search(r"query_results_[\w-]+\.csv", final_text)
246
  if match:
247
  filename = match.group(0)
248
- folder = "513935c4d2db2d2d"
 
 
 
249
  full_path = os.path.join(folder, filename)
250
 
 
 
 
 
 
 
 
251
  if os.path.exists(full_path):
252
- print(f"πŸ“‚ Reading result file: {full_path}")
253
  with open(full_path, "r", encoding="utf-8") as f:
254
  csv_data = f.read().strip()
255
- print("πŸ€– Response sent to user (from file):", csv_data[:300])
 
256
  return csv_data
257
  else:
258
- print(f"⚠️ File not found: {full_path}")
259
-
 
 
 
 
 
 
 
 
 
260
  return final_text
 
55
  "- Never use SELECT *\n"
56
  "- Prefer window functions over subqueries when possible\n"
57
  "- Always include a LIMIT for exploratory queries\n"
 
 
 
58
  "- Format dates and numbers for readability\n"
59
  )
60
 
 
103
  # ======================
104
  prompt += (
105
  "\n## Business Logic\n"
 
106
  "- A query mentioning an organization (e.g., 'New York Times') should search both `posts.author` and `providers.name`.\n"
107
+ "- Return all post types (spotlight, resource, insight) unless the user specifies otherwise.\n"
 
108
  "- Tags link posts to specific themes or disciplines.\n"
109
  "- A single post may have multiple tags, awards, or categories.\n"
110
  "- If the user mentions a year (e.g., 'in 2021'), filter with `EXTRACT(YEAR FROM published_date) = 2021`.\n"
111
  "- If the user says 'recently', filter posts from the last 90 days.\n"
112
  "- Always limit exploratory results to 9 rows.\n"
113
+ "\n"
114
+ "## CRITICAL: Search Strategy\n"
115
+ "**IMPORTANT**: Only 3 posts currently have tags. Most posts (7,245+) are NOT tagged yet.\n"
116
+ "\n"
117
+ "**Hybrid Search Approach (RECOMMENDED)**:\n"
118
+ "- ALWAYS use a hybrid approach combining tag search AND keyword search with OR logic.\n"
119
+ "- Use LEFT JOINs for tags (not INNER JOIN) so untagged posts are included.\n"
120
+ "\n"
121
+ "**Keyword Matching - Use PostgreSQL Regex for Exact Word Boundaries**:\n"
122
+ "- Use ~* operator for case-insensitive regex matching\n"
123
+ "- Use \\m and \\M for word boundaries (start and end of word)\n"
124
+ "- Pattern: column ~* '\\\\mkeyword\\\\M'\n"
125
+ "- Example: p.title ~* '\\\\mf1\\\\M' matches 'F1' but NOT 'profile' or 'if'\n"
126
+ "- This ensures exact word matching, not substring matching\n"
127
+ "\n"
128
+ "**When to use tag-only search**: Only if user explicitly mentions 'tagged with' or 'tag:'.\n"
129
+ "**When to use keyword-only search**: For author/organization names, or when tags are not relevant.\n"
130
+ "\n"
131
+ "This ensures maximum result coverage while the database is being enriched with tags.\n"
132
  )
133
 
134
  # ======================
 
159
  # ======================
160
  prompt += (
161
  "\n## Example Interactions\n"
162
+ "User: 'F1' or 'Show me F1 content'\n"
163
+ "Assistant: [call run_sql with \"SELECT DISTINCT p.id, p.title, p.source_url, p.author, p.published_date, p.image_url, p.type "
164
  "FROM posts p "
165
+ "LEFT JOIN post_tags pt ON p.id = pt.post_id "
166
+ "LEFT JOIN tags t ON pt.tag_id = t.id "
167
+ "LEFT JOIN providers pr ON p.provider_id = pr.id "
168
+ "WHERE t.name ~* '\\\\mf1\\\\M' OR t.name ~* '\\\\mformula\\\\M' "
169
+ "OR p.title ~* '\\\\mf1\\\\M' OR p.title ~* '\\\\mformula\\\\M' "
170
+ "OR p.author ~* '\\\\mf1\\\\M' "
171
+ "ORDER BY p.published_date DESC NULLS LAST LIMIT 9;\"]\n"
172
  "\nUser: 'Show me posts from The New York Times'\n"
173
+ "Assistant: [call run_sql with \"SELECT DISTINCT p.id, p.title, p.source_url, p.author, p.published_date, p.image_url, p.type "
174
  "FROM posts p "
175
+ "LEFT JOIN providers pr ON p.provider_id = pr.id "
176
+ "WHERE p.author ~* '\\\\mnew\\\\M.*\\\\myork\\\\M.*\\\\mtimes\\\\M' OR pr.name ~* '\\\\mnew\\\\M.*\\\\myork\\\\M.*\\\\mtimes\\\\M' "
177
+ "ORDER BY p.published_date DESC NULLS LAST LIMIT 9;\"]\n"
178
+ "\nUser: 'interactive visualizations'\n"
179
+ "Assistant: [call run_sql with \"SELECT DISTINCT p.id, p.title, p.source_url, p.author, p.published_date, p.image_url, p.type "
180
+ "FROM posts p "
181
+ "LEFT JOIN post_tags pt ON p.id = pt.post_id "
182
+ "LEFT JOIN tags t ON pt.tag_id = t.id "
183
+ "WHERE t.name ~* '\\\\minteractive\\\\M' OR p.title ~* '\\\\minteractive\\\\M' "
184
+ "OR p.title ~* '\\\\mvisualization\\\\M' OR t.name ~* '\\\\mdataviz\\\\M' "
185
+ "ORDER BY p.published_date DESC NULLS LAST LIMIT 9;\"]\n"
186
  )
187
 
188
  # ======================
 
190
  # ======================
191
  prompt += (
192
  "\nIMPORTANT:\n"
 
 
193
  "- Always return **only the raw CSV result** β€” no explanations, no JSON, no commentary.\n"
194
  "- Stop tool execution once the query result is obtained.\n"
195
  )
 
218
  db_tool = RunSqlTool(sql_runner=self.sql_runner)
219
 
220
  agent_memory = DemoAgentMemory(max_items=1000)
221
+ save_memory_tool = SaveQuestionToolArgsTool()
222
+ search_memory_tool = SearchSavedCorrectToolUsesTool()
223
 
224
  self.user_resolver = SimpleUserResolver()
225
 
 
232
  llm_service=llm,
233
  tool_registry=tools,
234
  user_resolver=self.user_resolver,
235
+ agent_memory=agent_memory,
236
  system_prompt_builder=CustomSQLSystemPromptBuilder("CoJournalist", self.sql_runner),
237
+ config=AgentConfig(stream_responses=False, max_tool_iterations=3)
238
  )
239
 
240
  async def ask(self, prompt_for_llm: str):
241
  ctx = RequestContext()
242
+ print(f"\n{'='*80}")
243
+ print(f"πŸ™‹ User Query: {prompt_for_llm}")
244
+ print(f"{'='*80}\n")
245
 
246
  final_text = ""
247
  seen_texts = set()
248
+ query_executed = False
249
+ result_row_count = 0
250
 
251
  async for component in self.agent.send_message(request_context=ctx, message=prompt_for_llm):
252
  simple = getattr(component, "simple_component", None)
253
  text = getattr(simple, "text", "") if simple else ""
254
  if text and text not in seen_texts:
255
+ print(f"πŸ’¬ LLM Response: {text[:300]}...")
256
  final_text += text + "\n"
257
  seen_texts.add(text)
258
 
259
  sql_query = getattr(component, "sql", None)
260
  if sql_query:
261
+ query_executed = True
262
+ print(f"\n🧾 SQL Query Generated:")
263
+ print(f"{'-'*80}")
264
+ print(f"{sql_query}")
265
+ print(f"{'-'*80}\n")
266
 
267
  metadata = getattr(component, "metadata", None)
268
  if metadata:
269
+ print(f"πŸ“‹ Query Metadata: {metadata}")
270
+ result_row_count = metadata.get("row_count", 0)
271
+ if result_row_count == 0:
272
+ print(f"⚠️ Query returned 0 rows - no data matched the criteria")
273
+ else:
274
+ print(f"βœ… Query returned {result_row_count} rows")
275
 
276
  component_type = getattr(component, "type", None)
277
  if component_type:
 
280
  match = re.search(r"query_results_[\w-]+\.csv", final_text)
281
  if match:
282
  filename = match.group(0)
283
+ # Calculate the user-specific folder based on the default user ID
284
+ import hashlib
285
+ user_hash = hashlib.sha256("[email protected]".encode()).hexdigest()[:16]
286
+ folder = user_hash
287
  full_path = os.path.join(folder, filename)
288
 
289
+ print(f"\nπŸ“ Looking for CSV file: {full_path}")
290
+
291
+ # Create folder if it doesn't exist
292
+ if not os.path.exists(folder):
293
+ print(f"πŸ“‚ Creating user directory: {folder}")
294
+ os.makedirs(folder, exist_ok=True)
295
+
296
  if os.path.exists(full_path):
297
+ print(f"βœ… Found CSV file, reading contents...")
298
  with open(full_path, "r", encoding="utf-8") as f:
299
  csv_data = f.read().strip()
300
+ print(f"πŸ“Š CSV Data Preview: {csv_data[:200]}...")
301
+ print(f"{'='*80}\n")
302
  return csv_data
303
  else:
304
+ print(f"❌ CSV file not found at: {full_path}")
305
+ # List files in the directory to help debug
306
+ if os.path.exists(folder):
307
+ files = os.listdir(folder)
308
+ print(f"πŸ“‚ Files in {folder}: {files}")
309
+
310
+ print(f"\n{'='*80}")
311
+ if not query_executed:
312
+ print(f"⚠️ No SQL query was executed by the LLM")
313
+ print(f"πŸ“€ Returning final response to user")
314
+ print(f"{'='*80}\n")
315
  return final_text
src/vanna_query_functions.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Vanna Query Function Templates
3
+
4
+ Defines SQL templates for different search strategies.
5
+ These are used by Vanna to generate accurate, performant SQL queries.
6
+ """
7
+
8
+ from typing import Dict, List
9
+
10
+
11
+ class QueryFunctions:
12
+ """
13
+ Collection of SQL query templates for different search strategies.
14
+ """
15
+
16
+ @staticmethod
17
+ def keyword_search(keywords: List[str], limit: int = 9) -> str:
18
+ """
19
+ Full-text keyword search across title, author, and provider.
20
+
21
+ Works for all posts in the database (7,248 posts).
22
+
23
+ Args:
24
+ keywords: List of keywords to search for
25
+ limit: Maximum number of results
26
+
27
+ Returns:
28
+ SQL query string
29
+ """
30
+ # Build regex conditions for each keyword with word boundaries
31
+ # Use PostgreSQL ~* operator for case-insensitive regex matching
32
+ # \m and \M are word boundary markers (start/end of word)
33
+ keyword_conditions = []
34
+ for keyword in keywords:
35
+ keyword_lower = keyword.lower()
36
+ # Escape special regex characters
37
+ keyword_escaped = keyword_lower.replace('\\', '\\\\').replace('.', '\\.').replace('+', '\\+')
38
+ keyword_conditions.append(f"""
39
+ (p.title ~* '\\m{keyword_escaped}\\M'
40
+ OR p.author ~* '\\m{keyword_escaped}\\M'
41
+ OR pr.name ~* '\\m{keyword_escaped}\\M')
42
+ """)
43
+
44
+ where_clause = " OR ".join(keyword_conditions)
45
+
46
+ return f"""
47
+ SELECT DISTINCT
48
+ p.id,
49
+ p.title,
50
+ p.source_url,
51
+ p.author,
52
+ p.published_date,
53
+ p.image_url,
54
+ p.type,
55
+ pr.name as provider_name
56
+ FROM posts p
57
+ LEFT JOIN providers pr ON p.provider_id = pr.id
58
+ WHERE {where_clause}
59
+ ORDER BY p.published_date DESC NULLS LAST
60
+ LIMIT {limit};
61
+ """
62
+
63
+ @staticmethod
64
+ def tag_search(tags: List[str], limit: int = 9) -> str:
65
+ """
66
+ Tag-based search.
67
+
68
+ Currently works for only 3 posts with tags.
69
+ As more posts are tagged, this will return more results.
70
+
71
+ Args:
72
+ tags: List of tag names to search for
73
+ limit: Maximum number of results
74
+
75
+ Returns:
76
+ SQL query string
77
+ """
78
+ # Format tag array for SQL
79
+ tags_lower = [f"'{tag.lower()}'" for tag in tags]
80
+ tags_array = f"ARRAY[{', '.join(tags_lower)}]"
81
+
82
+ return f"""
83
+ SELECT DISTINCT
84
+ p.id,
85
+ p.title,
86
+ p.source_url,
87
+ p.author,
88
+ p.published_date,
89
+ p.image_url,
90
+ p.type,
91
+ pr.name as provider_name,
92
+ string_agg(DISTINCT t.name, ', ') as tags
93
+ FROM posts p
94
+ JOIN post_tags pt ON p.id = pt.post_id
95
+ JOIN tags t ON pt.tag_id = t.id
96
+ LEFT JOIN providers pr ON p.provider_id = pr.id
97
+ WHERE LOWER(t.name) = ANY({tags_array})
98
+ GROUP BY p.id, p.title, p.source_url, p.author, p.published_date, p.image_url, p.type, pr.name
99
+ ORDER BY p.published_date DESC NULLS LAST
100
+ LIMIT {limit};
101
+ """
102
+
103
+ @staticmethod
104
+ def hybrid_search(keywords: List[str], tags: List[str], limit: int = 9) -> str:
105
+ """
106
+ Hybrid search combining tags AND keywords.
107
+
108
+ Best of both worlds:
109
+ - Finds tagged posts (currently 3)
110
+ - Falls back to keyword search for untagged posts (7,245)
111
+
112
+ Args:
113
+ keywords: List of keywords to search for
114
+ tags: List of tag names to search for
115
+ limit: Maximum number of results
116
+
117
+ Returns:
118
+ SQL query string
119
+ """
120
+ # Build tag conditions
121
+ tags_lower = [f"'{tag.lower()}'" for tag in tags]
122
+ tags_array = f"ARRAY[{', '.join(tags_lower)}]"
123
+
124
+ # Build regex keyword conditions with word boundaries
125
+ keyword_conditions = []
126
+ for keyword in keywords:
127
+ keyword_lower = keyword.lower()
128
+ # Escape special regex characters
129
+ keyword_escaped = keyword_lower.replace('\\', '\\\\').replace('.', '\\.').replace('+', '\\+')
130
+ keyword_conditions.append(f"""
131
+ (p.title ~* '\\m{keyword_escaped}\\M'
132
+ OR p.author ~* '\\m{keyword_escaped}\\M'
133
+ OR pr.name ~* '\\m{keyword_escaped}\\M')
134
+ """)
135
+
136
+ keyword_where = " OR ".join(keyword_conditions)
137
+
138
+ return f"""
139
+ SELECT DISTINCT
140
+ p.id,
141
+ p.title,
142
+ p.source_url,
143
+ p.author,
144
+ p.published_date,
145
+ p.image_url,
146
+ p.type,
147
+ pr.name as provider_name,
148
+ string_agg(DISTINCT t.name, ', ') as tags
149
+ FROM posts p
150
+ LEFT JOIN post_tags pt ON p.id = pt.post_id
151
+ LEFT JOIN tags t ON pt.tag_id = t.id
152
+ LEFT JOIN providers pr ON p.provider_id = pr.id
153
+ WHERE
154
+ LOWER(t.name) = ANY({tags_array})
155
+ OR ({keyword_where})
156
+ GROUP BY p.id, p.title, p.source_url, p.author, p.published_date, p.image_url, p.type, pr.name
157
+ ORDER BY p.published_date DESC NULLS LAST
158
+ LIMIT {limit};
159
+ """
160
+
161
+ @staticmethod
162
+ def search_by_author(author: str, limit: int = 9) -> str:
163
+ """
164
+ Search posts by specific author or organization.
165
+
166
+ Args:
167
+ author: Author name to search for
168
+ limit: Maximum number of results
169
+
170
+ Returns:
171
+ SQL query string
172
+ """
173
+ # Escape special regex characters
174
+ author_escaped = author.lower().replace('\\', '\\\\').replace('.', '\\.').replace('+', '\\+')
175
+
176
+ return f"""
177
+ SELECT DISTINCT
178
+ p.id,
179
+ p.title,
180
+ p.source_url,
181
+ p.author,
182
+ p.published_date,
183
+ p.image_url,
184
+ p.type,
185
+ pr.name as provider_name
186
+ FROM posts p
187
+ LEFT JOIN providers pr ON p.provider_id = pr.id
188
+ WHERE
189
+ p.author ~* '\\m{author_escaped}\\M'
190
+ OR pr.name ~* '\\m{author_escaped}\\M'
191
+ ORDER BY p.published_date DESC NULLS LAST
192
+ LIMIT {limit};
193
+ """
194
+
195
+ @staticmethod
196
+ def search_recent(days: int = 90, limit: int = 9) -> str:
197
+ """
198
+ Search for recent posts within the last N days.
199
+
200
+ Args:
201
+ days: Number of days to look back
202
+ limit: Maximum number of results
203
+
204
+ Returns:
205
+ SQL query string
206
+ """
207
+ return f"""
208
+ SELECT DISTINCT
209
+ p.id,
210
+ p.title,
211
+ p.source_url,
212
+ p.author,
213
+ p.published_date,
214
+ p.image_url,
215
+ p.type,
216
+ pr.name as provider_name
217
+ FROM posts p
218
+ LEFT JOIN providers pr ON p.provider_id = pr.id
219
+ WHERE
220
+ p.published_date >= CURRENT_DATE - INTERVAL '{days} days'
221
+ ORDER BY p.published_date DESC
222
+ LIMIT {limit};
223
+ """
224
+
225
+ @staticmethod
226
+ def search_by_type(post_type: str, limit: int = 9) -> str:
227
+ """
228
+ Search by post type (spotlight, insight, resource).
229
+
230
+ Args:
231
+ post_type: Type of post (spotlight, insight, resource)
232
+ limit: Maximum number of results
233
+
234
+ Returns:
235
+ SQL query string
236
+ """
237
+ return f"""
238
+ SELECT DISTINCT
239
+ p.id,
240
+ p.title,
241
+ p.source_url,
242
+ p.author,
243
+ p.published_date,
244
+ p.image_url,
245
+ p.type,
246
+ pr.name as provider_name
247
+ FROM posts p
248
+ LEFT JOIN providers pr ON p.provider_id = pr.id
249
+ WHERE p.type = '{post_type}'
250
+ ORDER BY p.published_date DESC NULLS LAST
251
+ LIMIT {limit};
252
+ """
253
+
254
+
255
+ def generate_query(search_type: str, **kwargs) -> str:
256
+ """
257
+ Generate SQL query based on search type.
258
+
259
+ Args:
260
+ search_type: Type of search (keyword, tag, hybrid, author, recent, type)
261
+ **kwargs: Parameters for the specific search type
262
+
263
+ Returns:
264
+ SQL query string
265
+ """
266
+ functions = {
267
+ "keyword": QueryFunctions.keyword_search,
268
+ "tag": QueryFunctions.tag_search,
269
+ "hybrid": QueryFunctions.hybrid_search,
270
+ "author": QueryFunctions.search_by_author,
271
+ "recent": QueryFunctions.search_recent,
272
+ "type": QueryFunctions.search_by_type,
273
+ }
274
+
275
+ if search_type not in functions:
276
+ raise ValueError(f"Unknown search type: {search_type}")
277
+
278
+ return functions[search_type](**kwargs)
279
+
280
+
281
+ # Example usage
282
+ if __name__ == "__main__":
283
+ # Test keyword search
284
+ print("=== KEYWORD SEARCH ===")
285
+ print(QueryFunctions.keyword_search(["F1", "racing"]))
286
+
287
+ print("\n=== TAG SEARCH ===")
288
+ print(QueryFunctions.tag_search(["dataviz", "interactive"]))
289
+
290
+ print("\n=== HYBRID SEARCH ===")
291
+ print(QueryFunctions.hybrid_search(
292
+ keywords=["visualization"],
293
+ tags=["dataviz", "interactive"]
294
+ ))
295
+
296
+ print("\n=== AUTHOR SEARCH ===")
297
+ print(QueryFunctions.search_by_author("New York Times"))
298
+
299
+ print("\n=== RECENT POSTS ===")
300
+ print(QueryFunctions.search_recent(days=30))