# Source: HuggingFace Space app.py (commit f4d5bcd, 4,930 bytes)
"""
HuggingFace Space - SPECTER2 Embedding API
Academic paper embeddings using SPECTER2 with adapters
"""
import os
from flask import Flask, request, jsonify
from transformers import AutoTokenizer
from adapters import AutoAdapterModel
import torch
import logging
# Configure logging (INFO so model-load progress is visible in Space logs)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize Flask app
app = Flask(__name__)
# Load SPECTER2 model with adapters.
# NOTE: this runs at import time, so the first worker to import this module
# pays the full download/load cost before any request can be served.
logger.info("Loading SPECTER2 base model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained('allenai/specter2_base')
model = AutoAdapterModel.from_pretrained('allenai/specter2_base')
logger.info("Loading SPECTER2 proximity adapter...")
# Load the proximity adapter for similarity/retrieval tasks.
# set_active=True makes the adapter part of every forward pass without
# needing an explicit model.set_active_adapters() call.
model.load_adapter("allenai/specter2", source="hf", load_as="specter2", set_active=True)
logger.info("SPECTER2 model loaded successfully!")
# Move to GPU if available; must happen AFTER load_adapter so the adapter
# weights are moved along with the base model.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
logger.info(f"Using device: {device}")
def get_embeddings(texts):
    """
    Compute SPECTER2 embeddings for a batch of texts.

    Args:
        texts: list of strings (typically paper title + abstract)

    Returns:
        numpy array of shape (len(texts), 768) containing one embedding
        per input text
    """
    # Tokenize the whole batch at once; truncate to the model's 512-token limit.
    tokenized = tokenizer(
        texts,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=512,
    ).to(device)

    # Inference only — no gradients needed.
    with torch.no_grad():
        model_output = model(**tokenized)

    # SPECTER2 uses the first ([CLS]) token's hidden state as the
    # document-level embedding.
    cls_vectors = model_output.last_hidden_state[:, 0, :]
    return cls_vectors.cpu().numpy()
@app.route('/')
def health():
    """Health check endpoint"""
    info = {
        "status": "healthy",
        "model": "allenai/specter2",
        "adapter": "proximity (similarity/retrieval)",
        "dimensions": 768,
        "device": str(device),
        "endpoints": {
            "/embed": "POST - Generate embedding for single text",
            "/batch_embed": "POST - Generate embeddings for multiple texts",
        },
    }
    return jsonify(info)
@app.route('/embed', methods=['POST'])
def embed_text():
    """
    Generate embedding for a single text query
    Request body:
    {
        "text": "Your paper title and abstract here"
    }
    Response:
    {
        "embedding": [0.123, -0.456, ...],
        "dimensions": 768
    }
    """
    try:
        payload = request.get_json()

        # Guard: body must be JSON with a 'text' key.
        if not payload or 'text' not in payload:
            return jsonify({
                "error": "Missing 'text' field in request body"
            }), 400

        query_text = payload['text']

        # Guard: value must be a string with actual content.
        if not isinstance(query_text, str) or len(query_text.strip()) == 0:
            return jsonify({
                "error": "Text must be a non-empty string"
            }), 400

        # Embed as a batch of one and unwrap the single result.
        vectors = get_embeddings([query_text])
        single = vectors[0]
        return jsonify({
            "embedding": single.tolist(),
            "dimensions": len(single)
        })
    except Exception as e:
        logger.error(f"Error generating embedding: {str(e)}")
        return jsonify({
            "error": "Internal server error",
            "message": str(e)
        }), 500
@app.route('/batch_embed', methods=['POST'])
def batch_embed_texts():
    """
    Generate embeddings for multiple texts (batch processing)
    Request body:
    {
        "texts": ["Paper 1 title and abstract", "Paper 2 title and abstract", ...]
    }
    Response:
    {
        "embeddings": [[0.123, ...], [0.456, ...], ...],
        "count": 2,
        "dimensions": 768
    }
    """
    try:
        data = request.get_json()
        if not data or 'texts' not in data:
            return jsonify({
                "error": "Missing 'texts' field in request body"
            }), 400
        texts = data['texts']
        if not isinstance(texts, list) or len(texts) == 0:
            return jsonify({
                "error": "Texts must be a non-empty list"
            }), 400
        # Limit batch size to prevent abuse
        if len(texts) > 100:
            return jsonify({
                "error": "Batch size too large (max 100 texts)"
            }), 400
        # Validate each item up front (mirrors /embed's check) so malformed
        # entries yield a clean 400 instead of a 500 from the tokenizer.
        if not all(isinstance(t, str) and t.strip() for t in texts):
            return jsonify({
                "error": "All texts must be non-empty strings"
            }), 400
        # Generate embeddings
        embeddings = get_embeddings(texts)
        return jsonify({
            "embeddings": embeddings.tolist(),
            "count": len(embeddings),
            "dimensions": embeddings.shape[1]
        })
    except Exception as e:
        logger.error(f"Error generating batch embeddings: {str(e)}")
        return jsonify({
            "error": "Internal server error",
            "message": str(e)
        }), 500
if __name__ == '__main__':
    # HuggingFace Spaces requires the app to listen on port 7860
    port = int(os.environ.get('PORT', 7860))
    logger.info(f"Starting server on port {port}...")
    # Bind to 0.0.0.0 so the server is reachable from outside the container;
    # debug=False to avoid the reloader re-importing (and re-loading the model).
    app.run(host='0.0.0.0', port=port, debug=False)