"""
HuggingFace Space - SPECTER2 Embedding API

Academic paper embeddings using SPECTER2 with adapters.

Endpoints:
    GET  /             health check and service metadata
    POST /embed        embedding for a single text
    POST /batch_embed  embeddings for up to MAX_BATCH_SIZE texts
"""

import os
from flask import Flask, request, jsonify
from transformers import AutoTokenizer
from adapters import AutoAdapterModel
import torch
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Maximum number of texts accepted by /batch_embed in a single request.
MAX_BATCH_SIZE = 100

# Maximum number of tokens per input; longer inputs are truncated.
MAX_LENGTH = 512

# Initialize Flask app
app = Flask(__name__)

# Load SPECTER2 model with adapters
logger.info("Loading SPECTER2 base model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained('allenai/specter2_base')
model = AutoAdapterModel.from_pretrained('allenai/specter2_base')

logger.info("Loading SPECTER2 proximity adapter...")
# Load the proximity adapter for similarity/retrieval tasks
model.load_adapter("allenai/specter2", source="hf", load_as="specter2", set_active=True)
logger.info("SPECTER2 model loaded successfully!")

# Move to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
# This service only runs inference: explicitly disable dropout and other
# training-only behavior rather than relying on the mode left by loading.
model.eval()
logger.info(f"Using device: {device}")


def get_embeddings(texts, max_length=MAX_LENGTH):
    """
    Generate SPECTER2 embeddings for a list of texts.

    Args:
        texts: List of strings (paper titles + abstracts).
        max_length: Maximum number of tokens per text; longer texts are
            truncated. Defaults to MAX_LENGTH (512).

    Returns:
        numpy array of shape (len(texts), hidden_size) holding the [CLS]
        token embedding of each text (768 columns for specter2_base).
    """
    if not texts:
        # The tokenizer cannot build a padded batch from zero inputs; return
        # an empty matrix of the correct width instead of raising.
        return torch.empty((0, model.config.hidden_size)).numpy()

    # Tokenize
    inputs = tokenizer(
        texts,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=max_length,
    ).to(device)

    # Generate embeddings without tracking gradients
    with torch.no_grad():
        output = model(**inputs)

    # Use [CLS] token embedding (first token)
    embeddings = output.last_hidden_state[:, 0, :]
    return embeddings.cpu().numpy()


def _internal_error(log_message, exc):
    """Log `exc` with its traceback and build the JSON 500 response.

    Must be called from inside an `except` block so the traceback is available.
    """
    logger.exception(f"{log_message}: {str(exc)}")
    return jsonify({
        "error": "Internal server error",
        "message": str(exc)
    }), 500


@app.route('/')
def health():
    """Health check endpoint"""
    return jsonify({
        "status": "healthy",
        "model": "allenai/specter2",
        "adapter": "proximity (similarity/retrieval)",
        "dimensions": 768,
        "device": str(device),
        "endpoints": {
            "/embed": "POST - Generate embedding for single text",
            "/batch_embed": "POST - Generate embeddings for multiple texts"
        }
    })


@app.route('/embed', methods=['POST'])
def embed_text():
    """
    Generate embedding for a single text query.

    Request body:
        { "text": "Your paper title and abstract here" }

    Response (200):
        { "embedding": [0.123, -0.456, ...], "dimensions": 768 }

    Errors:
        400 if the body is not a JSON object with a non-empty string "text".
        500 if embedding generation fails.
    """
    # silent=True: malformed JSON or a non-JSON Content-Type yields None
    # (reported as a 400 below) instead of a werkzeug HTTPException that the
    # broad `except` further down would otherwise turn into a 500.
    data = request.get_json(silent=True)
    # A JSON array would pass a plain membership test and then fail on
    # data['text'], so require an object explicitly.
    if not isinstance(data, dict) or 'text' not in data:
        return jsonify({
            "error": "Missing 'text' field in request body"
        }), 400

    text = data['text']
    if not isinstance(text, str) or len(text.strip()) == 0:
        return jsonify({
            "error": "Text must be a non-empty string"
        }), 400

    try:
        # Generate embedding
        embeddings = get_embeddings([text])
        return jsonify({
            "embedding": embeddings[0].tolist(),
            "dimensions": len(embeddings[0])
        })
    except Exception as e:
        return _internal_error("Error generating embedding", e)


@app.route('/batch_embed', methods=['POST'])
def batch_embed_texts():
    """
    Generate embeddings for multiple texts (batch processing).

    Request body:
        { "texts": ["Paper 1 title and abstract", "Paper 2 title and abstract", ...] }

    Response (200):
        { "embeddings": [[0.123, ...], [0.456, ...], ...], "count": 2, "dimensions": 768 }

    Errors:
        400 if the body is not a JSON object whose "texts" is a non-empty list
            of at most MAX_BATCH_SIZE strings.
        500 if embedding generation fails.
    """
    # See embed_text: silent=True turns bad/non-JSON bodies into a 400.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'texts' not in data:
        return jsonify({
            "error": "Missing 'texts' field in request body"
        }), 400

    texts = data['texts']
    if not isinstance(texts, list) or len(texts) == 0:
        return jsonify({
            "error": "Texts must be a non-empty list"
        }), 400

    # Limit batch size to prevent abuse
    if len(texts) > MAX_BATCH_SIZE:
        return jsonify({
            "error": f"Batch size too large (max {MAX_BATCH_SIZE} texts)"
        }), 400

    # Non-string items make the tokenizer raise (or, for nested lists, be
    # silently read as text pairs); reject them as client errors.
    if not all(isinstance(t, str) for t in texts):
        return jsonify({
            "error": "All texts must be strings"
        }), 400

    try:
        # Generate embeddings
        embeddings = get_embeddings(texts)
        return jsonify({
            "embeddings": embeddings.tolist(),
            "count": len(embeddings),
            "dimensions": embeddings.shape[1]
        })
    except Exception as e:
        return _internal_error("Error generating batch embeddings", e)


if __name__ == '__main__':
    # HuggingFace Spaces requires the app to listen on port 7860
    port = int(os.environ.get('PORT', 7860))
    logger.info(f"Starting server on port {port}...")
    app.run(host='0.0.0.0', port=port, debug=False)