File size: 4,930 Bytes
f4d5bcd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
"""
HuggingFace Space - SPECTER2 Embedding API
Academic paper embeddings using SPECTER2 with adapters
"""

import os
from flask import Flask, request, jsonify
from transformers import AutoTokenizer
from adapters import AutoAdapterModel
import torch
import logging

# Configure logging
# Root-level INFO logging so model-load progress is visible in Space logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize Flask app
app = Flask(__name__)

# Load SPECTER2 model with adapters
# NOTE: this runs at import time, so the process blocks (and downloads
# weights on first run) before the server can accept requests.
logger.info("Loading SPECTER2 base model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained('allenai/specter2_base')
model = AutoAdapterModel.from_pretrained('allenai/specter2_base')

logger.info("Loading SPECTER2 proximity adapter...")
# Load the proximity adapter for similarity/retrieval tasks
# (set_active=True makes it the adapter used by all forward passes).
model.load_adapter("allenai/specter2", source="hf", load_as="specter2", set_active=True)
logger.info("SPECTER2 model loaded successfully!")

# Move to GPU if available
# NOTE(review): no explicit model.eval() call here — from_pretrained is
# presumed to return the model in eval mode; confirm dropout is disabled.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
logger.info(f"Using device: {device}")


def get_embeddings(texts):
    """
    Compute SPECTER2 embeddings for a batch of texts.

    Args:
        texts: List of strings (paper titles + abstracts)

    Returns:
        numpy array of embeddings (batch_size, 768)
    """
    # Tokenize the whole batch at once, padding to the longest item and
    # truncating anything beyond the 512-token model limit.
    encoded = tokenizer(
        texts,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=512
    ).to(device)

    # Forward pass without gradient tracking; the [CLS] token (position 0)
    # is SPECTER2's document-level representation.
    with torch.no_grad():
        hidden_states = model(**encoded).last_hidden_state

    cls_vectors = hidden_states[:, 0, :]
    return cls_vectors.cpu().numpy()


@app.route('/')
def health():
    """Health check endpoint"""
    payload = {
        "status": "healthy",
        "model": "allenai/specter2",
        "adapter": "proximity (similarity/retrieval)",
        "dimensions": 768,
        "device": str(device),
        "endpoints": {
            "/embed": "POST - Generate embedding for single text",
            "/batch_embed": "POST - Generate embeddings for multiple texts"
        }
    }
    return jsonify(payload)


@app.route('/embed', methods=['POST'])
def embed_text():
    """
    Generate embedding for a single text query

    Request body:
    {
        "text": "Your paper title and abstract here"
    }

    Response:
    {
        "embedding": [0.123, -0.456, ...],
        "dimensions": 768
    }

    Returns 400 for a missing/empty 'text' field, 500 on model errors.
    """
    try:
        data = request.get_json()

        # Reject bodies that are not JSON objects with a 'text' key.
        if not data or 'text' not in data:
            return jsonify({
                "error": "Missing 'text' field in request body"
            }), 400

        text = data['text']

        if not isinstance(text, str) or len(text.strip()) == 0:
            return jsonify({
                "error": "Text must be a non-empty string"
            }), 400

        # Generate embedding (batch of one; take the single row back out).
        embeddings = get_embeddings([text])

        return jsonify({
            "embedding": embeddings[0].tolist(),
            "dimensions": len(embeddings[0])
        })

    except Exception as e:
        # logger.exception preserves the traceback; the f-string/error()
        # form it replaces silently dropped it.
        logger.exception("Error generating embedding: %s", e)
        return jsonify({
            "error": "Internal server error",
            "message": str(e)
        }), 500


@app.route('/batch_embed', methods=['POST'])
def batch_embed_texts():
    """
    Generate embeddings for multiple texts (batch processing)

    Request body:
    {
        "texts": ["Paper 1 title and abstract", "Paper 2 title and abstract", ...]
    }

    Response:
    {
        "embeddings": [[0.123, ...], [0.456, ...], ...],
        "count": 2,
        "dimensions": 768
    }

    Returns 400 for a missing/empty/invalid 'texts' field or an oversized
    batch, 500 on model errors.
    """
    try:
        data = request.get_json()

        if not data or 'texts' not in data:
            return jsonify({
                "error": "Missing 'texts' field in request body"
            }), 400

        texts = data['texts']

        if not isinstance(texts, list) or len(texts) == 0:
            return jsonify({
                "error": "Texts must be a non-empty list"
            }), 400

        # Validate element types up front: a non-string item would otherwise
        # blow up inside the tokenizer and surface as a misleading 500.
        if not all(isinstance(t, str) and t.strip() for t in texts):
            return jsonify({
                "error": "All texts must be non-empty strings"
            }), 400

        # Limit batch size to prevent abuse
        if len(texts) > 100:
            return jsonify({
                "error": "Batch size too large (max 100 texts)"
            }), 400

        # Generate embeddings
        embeddings = get_embeddings(texts)

        return jsonify({
            "embeddings": embeddings.tolist(),
            "count": len(embeddings),
            "dimensions": embeddings.shape[1]
        })

    except Exception as e:
        # logger.exception preserves the traceback; the f-string/error()
        # form it replaces silently dropped it.
        logger.exception("Error generating batch embeddings: %s", e)
        return jsonify({
            "error": "Internal server error",
            "message": str(e)
        }), 500


if __name__ == '__main__':
    # HuggingFace Spaces requires the app to listen on port 7860
    listen_port = int(os.environ.get('PORT', 7860))
    logger.info(f"Starting server on port {listen_port}...")
    app.run(host='0.0.0.0', port=listen_port, debug=False)