yangg40 commited on
Commit
f4d5bcd
·
verified ·
1 Parent(s): 73de42d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +186 -0
app.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HuggingFace Space - SPECTER2 Embedding API
3
+ Academic paper embeddings using SPECTER2 with adapters
4
+ """
5
+
6
+ import os
7
+ from flask import Flask, request, jsonify
8
+ from transformers import AutoTokenizer
9
+ from adapters import AutoAdapterModel
10
+ import torch
11
+ import logging
12
+
13
+ # Configure logging
14
+ logging.basicConfig(level=logging.INFO)
15
+ logger = logging.getLogger(__name__)
16
+
17
+ # Initialize Flask app
18
+ app = Flask(__name__)
19
+
20
+ # Load SPECTER2 model with adapters
21
+ logger.info("Loading SPECTER2 base model and tokenizer...")
22
+ tokenizer = AutoTokenizer.from_pretrained('allenai/specter2_base')
23
+ model = AutoAdapterModel.from_pretrained('allenai/specter2_base')
24
+
25
+ logger.info("Loading SPECTER2 proximity adapter...")
26
+ # Load the proximity adapter for similarity/retrieval tasks
27
+ model.load_adapter("allenai/specter2", source="hf", load_as="specter2", set_active=True)
28
+ logger.info("SPECTER2 model loaded successfully!")
29
+
30
+ # Move to GPU if available
31
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
32
+ model = model.to(device)
33
+ logger.info(f"Using device: {device}")
34
+
35
+
36
def get_embeddings(texts, max_length=512):
    """
    Generate SPECTER2 embeddings for a list of texts.

    Args:
        texts: List of strings (paper titles + abstracts).
        max_length: Maximum token length per text before truncation.
            Defaults to 512, the model's context limit.

    Returns:
        numpy array of embeddings with shape (len(texts), 768).
    """
    # Tokenize the whole batch at once: pad to the longest text in the
    # batch, truncate anything beyond max_length tokens.
    inputs = tokenizer(
        texts,
        padding=True,
        truncation=True,
        return_tensors="pt",
        max_length=max_length
    ).to(device)

    # Inference only -- no gradients needed.
    with torch.no_grad():
        output = model(**inputs)
        # SPECTER2 uses the [CLS] token (position 0) as the document embedding.
        embeddings = output.last_hidden_state[:, 0, :]

    return embeddings.cpu().numpy()
62
+
63
+
64
@app.route('/')
def health():
    """Health check: report model/runtime info and the available endpoints."""
    info = {
        "status": "healthy",
        "model": "allenai/specter2",
        "adapter": "proximity (similarity/retrieval)",
        "dimensions": 768,
        "device": str(device),
        "endpoints": {
            "/embed": "POST - Generate embedding for single text",
            "/batch_embed": "POST - Generate embeddings for multiple texts"
        },
    }
    return jsonify(info)
78
+
79
+
80
@app.route('/embed', methods=['POST'])
def embed_text():
    """
    Generate embedding for a single text query.

    Request body:
        {"text": "Your paper title and abstract here"}

    Response:
        {"embedding": [0.123, -0.456, ...], "dimensions": 768}

    Returns 400 for missing or invalid input (including malformed JSON),
    500 for unexpected failures during embedding.
    """
    try:
        # silent=True: malformed JSON / wrong content type yields None,
        # which the check below turns into a 400 -- previously the parse
        # error was swallowed by the broad except and reported as a 500.
        data = request.get_json(silent=True)

        if not data or 'text' not in data:
            return jsonify({
                "error": "Missing 'text' field in request body"
            }), 400

        text = data['text']

        if not isinstance(text, str) or len(text.strip()) == 0:
            return jsonify({
                "error": "Text must be a non-empty string"
            }), 400

        # Generate embedding (batch of one).
        embeddings = get_embeddings([text])

        return jsonify({
            "embedding": embeddings[0].tolist(),
            "dimensions": len(embeddings[0])
        })

    except Exception as e:
        # logger.exception records the full traceback, not just the message.
        logger.exception("Error generating embedding")
        return jsonify({
            "error": "Internal server error",
            "message": str(e)
        }), 500
125
+
126
+
127
@app.route('/batch_embed', methods=['POST'])
def batch_embed_texts():
    """
    Generate embeddings for multiple texts (batch processing).

    Request body:
        {"texts": ["Paper 1 title and abstract", "Paper 2 title and abstract", ...]}

    Response:
        {"embeddings": [[0.123, ...], [0.456, ...], ...],
         "count": 2, "dimensions": 768}

    Returns 400 for missing or invalid input (malformed JSON, non-string
    items, or more than 100 texts), 500 for unexpected failures.
    """
    try:
        # silent=True: malformed JSON / wrong content type yields None,
        # which the check below turns into a 400 -- previously the parse
        # error was swallowed by the broad except and reported as a 500.
        data = request.get_json(silent=True)

        if not data or 'texts' not in data:
            return jsonify({
                "error": "Missing 'texts' field in request body"
            }), 400

        texts = data['texts']

        if not isinstance(texts, list) or len(texts) == 0:
            return jsonify({
                "error": "Texts must be a non-empty list"
            }), 400

        # Validate items up front: a non-string element would otherwise
        # crash the tokenizer and surface as a 500 instead of a 400.
        if not all(isinstance(t, str) and t.strip() for t in texts):
            return jsonify({
                "error": "All texts must be non-empty strings"
            }), 400

        # Limit batch size to prevent abuse
        if len(texts) > 100:
            return jsonify({
                "error": "Batch size too large (max 100 texts)"
            }), 400

        # Generate embeddings
        embeddings = get_embeddings(texts)

        return jsonify({
            "embeddings": embeddings.tolist(),
            "count": len(embeddings),
            "dimensions": embeddings.shape[1]
        })

    except Exception as e:
        # logger.exception records the full traceback, not just the message.
        logger.exception("Error generating batch embeddings")
        return jsonify({
            "error": "Internal server error",
            "message": str(e)
        }), 500
180
+
181
+
182
if __name__ == '__main__':
    # HuggingFace Spaces expects the server on port 7860 unless PORT overrides it.
    listen_port = int(os.environ.get('PORT', 7860))
    logger.info(f"Starting server on port {listen_port}...")
    app.run(host='0.0.0.0', port=listen_port, debug=False)