import gradio as gr
import torch
import numpy as np
from sentence_transformers import SentenceTransformer, util

# 1. Load your fine-tuned retrieval model (on CodeSearchNet - Python)
# This is the model you pushed to the Hugging Face Hub after training.
model_name = "juanwisz/modernbert-python-code-retrieval"
device = "cuda" if torch.cuda.is_available() else "cpu"

# SentenceTransformer automatically handles tokenizer + embedding
embedding_model = SentenceTransformer(model_name, device=device)
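# Optional sanity check (a small sketch, not required by the app): SentenceTransformer
# models report their output embedding size, which is a cheap way to confirm the
# checkpoint loaded as expected before serving requests.
print(
    f"Loaded {model_name} on {device} "
    f"(embedding dim: {embedding_model.get_sentence_embedding_dimension()})"
)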
# 2. Define a function to:
#    - Parse code snippets from the text box (split by "---")
#    - Compute embeddings for the user's query and each snippet
#    - Return the top 3 most relevant code snippets based on cosine similarity
def retrieve_top_snippets(query, code_input):
    # Split the input on "---"; trim each snippet and drop empty pieces
    snippets = [s.strip() for s in code_input.split("---") if s.strip()]

    # Edge case: if the user provided no code, return a hint instead of searching
    if len(snippets) == 0:
        return "No code snippets detected (make sure to separate them with ---)."

    # Embed the query and the code snippets
    query_emb = embedding_model.encode(query, convert_to_tensor=True)
    snippets_emb = embedding_model.encode(snippets, convert_to_tensor=True)

    # Cosine similarity between the query and every snippet;
    # util.cos_sim returns a [1 x num_snippets] matrix, so [0] gives a 1-D score vector
    cos_scores = util.cos_sim(query_emb, snippets_emb)[0]

    # Pick the indices of the highest-scoring snippets (at most 3)
    top_indices = torch.topk(cos_scores, k=min(3, len(snippets))).indices

    # Format each match as Markdown: the score followed by the snippet in a code block
    results = []
    for idx in top_indices.tolist():
        score = cos_scores[idx].item()
        snippet_text = snippets[idx]
        results.append(f"**Score**: {score:.4f}\n```python\n{snippet_text}\n```")

    # Join all results, separated by blank lines
    return "\n\n".join(results)
#####################
### Gradio Layout ###
#####################
css = """
#container {
    margin: 0 auto;
    max-width: 700px;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown(
        "# Code Retrieval using ModernBERT\n"
        "Enter a natural language query and paste multiple Python code snippets, "
        "delimited by `---`. We'll return the top 3 matches."
    )
    with gr.Column(elem_id="container"):
        with gr.Row():
            query_input = gr.Textbox(
                label="Natural Language Query",
                placeholder="What does your function do? e.g., 'Parse JSON from a string'"
            )
            code_snippets_input = gr.Textbox(
                label="Paste Python functions (delimited by ---)",
                lines=10,
                placeholder=(
                    "Example:\n---\ndef parse_json(data):\n    return json.loads(data)\n"
                    "---\ndef add_numbers(a, b):\n    return a + b\n---"
                )
            )
        search_btn = gr.Button("Search", variant="primary")
        results_output = gr.Markdown(label="Top 3 Matches")

        # On click, run the retrieval function
        search_btn.click(
            fn=retrieve_top_snippets,
            inputs=[query_input, code_snippets_input],
            outputs=results_output
        )

if __name__ == "__main__":
    demo.launch()
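# Once the app is running, the same retrieval can be called programmatically through
# gradio_client (a sketch; the Space id below is a placeholder, and the endpoint name
# should be checked against the app's "Use via API" page):
#
# from gradio_client import Client
# client = Client("your-username/your-space-name")  # placeholder Space id
# result = client.predict(
#     "parse a JSON string",                                   # query
#     "def parse_json(data):\n    return json.loads(data)",    # code snippets
#     api_name="/retrieve_top_snippets",  # assumption: endpoint named after the function
# )
# print(result)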