"""
HuggingFace Space - SPECTER2 Embedding API
Academic paper embeddings using SPECTER2 with adapters
"""
import logging
import os

import torch
from flask import Flask, request, jsonify
from transformers import AutoTokenizer
from adapters import AutoAdapterModel

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize Flask app
app = Flask(__name__)

# Load SPECTER2 model with adapters
logger.info("Loading SPECTER2 base model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained('allenai/specter2_base')
model = AutoAdapterModel.from_pretrained('allenai/specter2_base')
logger.info("Loading SPECTER2 proximity adapter...")
# Load the proximity adapter for similarity/retrieval tasks
model.load_adapter("allenai/specter2", source="hf", load_as="specter2", set_active=True)
logger.info("SPECTER2 model loaded successfully!")

# Move to GPU if available; keep the model in inference mode
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model.eval()
logger.info(f"Using device: {device}")

def get_embeddings(texts):
"""
Generate SPECTER2 embeddings for a list of texts
Args:
texts: List of strings (paper titles + abstracts)
Returns:
numpy array of embeddings (batch_size, 768)
"""
# Tokenize
inputs = tokenizer(
texts,
padding=True,
truncation=True,
return_tensors="pt",
max_length=512
).to(device)
# Generate embeddings
with torch.no_grad():
output = model(**inputs)
# Use [CLS] token embedding (first token)
embeddings = output.last_hidden_state[:, 0, :]
return embeddings.cpu().numpy()
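
# The SPECTER2 model card builds inputs as "title + sep_token + abstract". The
# routes below embed client text as-is, so this optional helper is an
# illustrative sketch of that convention for callers (not used by any route):
def format_paper(title, abstract=""):
    """Join title and abstract with the tokenizer's separator token."""
    return title + tokenizer.sep_token + (abstract or "")
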
@app.route('/')
def health():
"""Health check endpoint"""
return jsonify({
"status": "healthy",
"model": "allenai/specter2",
"adapter": "proximity (similarity/retrieval)",
"dimensions": 768,
"device": str(device),
"endpoints": {
"/embed": "POST - Generate embedding for single text",
"/batch_embed": "POST - Generate embeddings for multiple texts"
}
})
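
# Quick check (assumes a local run; on Spaces, substitute the Space URL):
#   curl http://localhost:7860/
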
@app.route('/embed', methods=['POST'])
def embed_text():
"""
Generate embedding for a single text query
Request body:
{
"text": "Your paper title and abstract here"
}
Response:
{
"embedding": [0.123, -0.456, ...],
"dimensions": 768
}
"""
try:
data = request.get_json()
if not data or 'text' not in data:
return jsonify({
"error": "Missing 'text' field in request body"
}), 400
text = data['text']
if not isinstance(text, str) or len(text.strip()) == 0:
return jsonify({
"error": "Text must be a non-empty string"
}), 400
# Generate embedding
embeddings = get_embeddings([text])
return jsonify({
"embedding": embeddings[0].tolist(),
"dimensions": len(embeddings[0])
})
except Exception as e:
logger.error(f"Error generating embedding: {str(e)}")
return jsonify({
"error": "Internal server error",
"message": str(e)
}), 500
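
# Example call (a sketch; assumes the server is reachable at localhost:7860):
#   curl -X POST http://localhost:7860/embed \
#     -H "Content-Type: application/json" \
#     -d '{"text": "A sample paper title and abstract."}'
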
@app.route('/batch_embed', methods=['POST'])
def batch_embed_texts():
"""
Generate embeddings for multiple texts (batch processing)
Request body:
{
"texts": ["Paper 1 title and abstract", "Paper 2 title and abstract", ...]
}
Response:
{
"embeddings": [[0.123, ...], [0.456, ...], ...],
"count": 2,
"dimensions": 768
}
"""
try:
data = request.get_json()
if not data or 'texts' not in data:
return jsonify({
"error": "Missing 'texts' field in request body"
}), 400
texts = data['texts']
if not isinstance(texts, list) or len(texts) == 0:
return jsonify({
"error": "Texts must be a non-empty list"
}), 400
# Limit batch size to prevent abuse
if len(texts) > 100:
return jsonify({
"error": "Batch size too large (max 100 texts)"
}), 400
# Generate embeddings
embeddings = get_embeddings(texts)
return jsonify({
"embeddings": embeddings.tolist(),
"count": len(embeddings),
"dimensions": embeddings.shape[1]
})
except Exception as e:
logger.error(f"Error generating batch embeddings: {str(e)}")
return jsonify({
"error": "Internal server error",
"message": str(e)
}), 500
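
# Example call (same host assumption as above):
#   curl -X POST http://localhost:7860/batch_embed \
#     -H "Content-Type: application/json" \
#     -d '{"texts": ["First paper title and abstract.", "Second paper title and abstract."]}'
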
if __name__ == '__main__':
# HuggingFace Spaces routes traffic to the app on port 7860 by default
port = int(os.environ.get('PORT', 7860))
logger.info(f"Starting server on port {port}...")
app.run(host='0.0.0.0', port=port, debug=False)
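
# Note: Flask's built-in server is fine for a demo Space; for heavier traffic a
# WSGI server is the usual choice, e.g.: gunicorn app:app --bind 0.0.0.0:7860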