File size: 1,191 Bytes
478805f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
import io
import logging

import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
from transformers import AutoImageProcessor, AutoModel
app = FastAPI()

# CORS policy (for local testing or cross-origin production use).
# NOTE(review): wildcard origins combined with allow_credentials=True is
# generally considered unsafe (the CORS spec does not allow "*" together with
# credentials) — verify how CORSMiddleware handles this combination and
# whether this API needs credentials at all before going to production.
_CORS_OPTIONS = {
    "allow_origins": ["*"],  # Change this in production
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# Load the image processor and model once, at import time, so every request
# handled by this module reuses the same instances.
_MODEL_ID = "facebook/dinov2-base"
processor = AutoImageProcessor.from_pretrained(_MODEL_ID)
model = AutoModel.from_pretrained(_MODEL_ID)
# Switch the network to evaluation mode (affects layers such as dropout).
model.eval()
logger = logging.getLogger(__name__)


def _decode_image(image_bytes: bytes) -> Image.Image:
    """Decode raw upload bytes into an RGB PIL image.

    Raises:
        OSError: the bytes are not a readable image (PIL's
            ``UnidentifiedImageError`` is an ``OSError`` subclass and is raised
            for empty or unrecognized data).
        Image.DecompressionBombError: the image exceeds PIL's pixel limit.
    """
    with Image.open(io.BytesIO(image_bytes)) as img:
        # convert() returns a new, fully loaded image, so the result stays
        # usable after the ``with`` block closes the source image.
        return img.convert("RGB")


def _embed_image(image: Image.Image) -> list:
    """Run the model on one image and mean-pool its final hidden states."""
    inputs = processor(images=image, return_tensors="pt")
    # torch.no_grad() is thread-local, so it must be entered here, in the
    # worker thread that actually runs the forward pass.
    with torch.no_grad():
        outputs = model(**inputs)
    # Average over the token axis (dim=1) of last_hidden_state. This pools
    # EVERY token position, the CLS token included; an earlier comment said
    # "without CLS", which this code has never done.
    # NOTE(review): if excluding CLS was intended, pool
    # last_hidden_state[:, 1:] instead. That changes every embedding, so any
    # previously stored vectors would have to be recomputed.
    return outputs.last_hidden_state.mean(dim=1).squeeze().tolist()


@app.post("/embedding")
async def get_embedding(file: UploadFile = File(...)):
    """Return an image embedding for an uploaded image file.

    Success: ``{"embedding": [float, ...]}`` with status 200.
    Failure: ``{"error": "<message>"}``, the same body shape as before. Status
    is 400 when the upload cannot be decoded as an image and 500 for any other
    failure. Errors used to be sent with 200 OK, which hid them from HTTP
    clients and monitoring.

    Decoding and inference are CPU-bound, so they run in the threadpool
    instead of blocking the event loop (and every other request) while they
    execute.
    """
    try:
        image_bytes = await file.read()
        image = await run_in_threadpool(_decode_image, image_bytes)
    except (OSError, ValueError, Image.DecompressionBombError) as e:
        # Corrupt, empty, unsupported or oversized upload: a client error.
        return JSONResponse(status_code=400, content={"error": str(e)})

    try:
        embedding = await run_in_threadpool(_embed_image, image)
    except Exception as e:  # request boundary: log the traceback, then report
        logger.exception("Embedding failed for upload %r", file.filename)
        return JSONResponse(status_code=500, content={"error": str(e)})

    return {"embedding": embedding}
|