# NOTE(review): the original paste began with HuggingFace page residue
# ("Spaces: / Sleeping / Sleeping") that is not part of the program.
| """ | |
| HuggingFace Space - SPECTER2 Embedding API | |
| Academic paper embeddings using SPECTER2 with adapters | |
| """ | |
| import os | |
| from flask import Flask, request, jsonify | |
| from transformers import AutoTokenizer | |
| from adapters import AutoAdapterModel | |
| import torch | |
| import logging | |
| # Configure logging | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| # Initialize Flask app | |
| app = Flask(__name__) | |
| # Load SPECTER2 model with adapters | |
| logger.info("Loading SPECTER2 base model and tokenizer...") | |
| tokenizer = AutoTokenizer.from_pretrained('allenai/specter2_base') | |
| model = AutoAdapterModel.from_pretrained('allenai/specter2_base') | |
| logger.info("Loading SPECTER2 proximity adapter...") | |
| # Load the proximity adapter for similarity/retrieval tasks | |
| model.load_adapter("allenai/specter2", source="hf", load_as="specter2", set_active=True) | |
| logger.info("SPECTER2 model loaded successfully!") | |
| # Move to GPU if available | |
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
| model = model.to(device) | |
| logger.info(f"Using device: {device}") | |
def get_embeddings(texts):
    """
    Generate SPECTER2 embeddings for a list of texts.

    Args:
        texts: List of strings (paper titles + abstracts).

    Returns:
        numpy array of embeddings with shape (batch_size, 768).
    """
    # Tokenize the whole batch at once: pad to the longest item and
    # truncate to the model's 512-token context window.
    inputs = tokenizer(
        texts,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=512,
    ).to(device)

    # Inference only — no gradients needed.
    with torch.no_grad():
        output = model(**inputs)

    # SPECTER2 convention: the [CLS] (first) token embedding is the
    # document-level representation.
    embeddings = output.last_hidden_state[:, 0, :]
    return embeddings.cpu().numpy()
# NOTE(review): no @app.route decorator survived in the pasted source, so
# this handler was never registered with Flask. '/' is the assumed path
# (standard for a Space landing/health page) — confirm against deployment.
@app.route('/')
def health():
    """Health check endpoint: reports model/adapter info and available routes."""
    return jsonify({
        "status": "healthy",
        "model": "allenai/specter2",
        "adapter": "proximity (similarity/retrieval)",
        "dimensions": 768,
        "device": str(device),
        "endpoints": {
            "/embed": "POST - Generate embedding for single text",
            "/batch_embed": "POST - Generate embeddings for multiple texts"
        }
    })
# NOTE(review): the @app.route decorator was missing from the pasted source;
# the path below is the one documented in health()'s endpoint listing.
@app.route('/embed', methods=['POST'])
def embed_text():
    """
    Generate embedding for a single text query.

    Request body:
        {
            "text": "Your paper title and abstract here"
        }

    Response:
        {
            "embedding": [0.123, -0.456, ...],
            "dimensions": 768
        }
    """
    try:
        data = request.get_json()

        # Validate the request payload before touching the model.
        if not data or 'text' not in data:
            return jsonify({
                "error": "Missing 'text' field in request body"
            }), 400

        text = data['text']
        if not isinstance(text, str) or len(text.strip()) == 0:
            return jsonify({
                "error": "Text must be a non-empty string"
            }), 400

        # Generate embedding (batch of one).
        embeddings = get_embeddings([text])

        return jsonify({
            "embedding": embeddings[0].tolist(),
            "dimensions": len(embeddings[0])
        })

    except Exception as e:
        # Boundary handler: log and return a generic 500 rather than crash
        # the worker on unexpected model/tokenizer failures.
        logger.error(f"Error generating embedding: {str(e)}")
        return jsonify({
            "error": "Internal server error",
            "message": str(e)
        }), 500
# NOTE(review): the @app.route decorator was missing from the pasted source;
# the path below is the one documented in health()'s endpoint listing.
@app.route('/batch_embed', methods=['POST'])
def batch_embed_texts():
    """
    Generate embeddings for multiple texts (batch processing).

    Request body:
        {
            "texts": ["Paper 1 title and abstract", "Paper 2 title and abstract", ...]
        }

    Response:
        {
            "embeddings": [[0.123, ...], [0.456, ...], ...],
            "count": 2,
            "dimensions": 768
        }
    """
    try:
        data = request.get_json()

        # Validate the request payload before touching the model.
        if not data or 'texts' not in data:
            return jsonify({
                "error": "Missing 'texts' field in request body"
            }), 400

        texts = data['texts']
        if not isinstance(texts, list) or len(texts) == 0:
            return jsonify({
                "error": "Texts must be a non-empty list"
            }), 400

        # Limit batch size to prevent abuse
        if len(texts) > 100:
            return jsonify({
                "error": "Batch size too large (max 100 texts)"
            }), 400

        # Generate embeddings for the whole batch in one forward pass.
        embeddings = get_embeddings(texts)

        return jsonify({
            "embeddings": embeddings.tolist(),
            "count": len(embeddings),
            "dimensions": embeddings.shape[1]
        })

    except Exception as e:
        # Boundary handler: log and return a generic 500 rather than crash
        # the worker on unexpected model/tokenizer failures.
        logger.error(f"Error generating batch embeddings: {str(e)}")
        return jsonify({
            "error": "Internal server error",
            "message": str(e)
        }), 500
if __name__ == '__main__':
    # HuggingFace Spaces requires the app to listen on port 7860;
    # honor a PORT override from the environment for local runs.
    port = int(os.environ.get('PORT', 7860))
    logger.info(f"Starting server on port {port}...")
    # 0.0.0.0 so the container's mapped port is reachable; debug off in prod.
    app.run(host='0.0.0.0', port=port, debug=False)