|
|
import io
import logging

import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from transformers import AutoImageProcessor, AutoModel
|
|
|
|
|
app = FastAPI()

# Cross-origin policy: every origin, method and header is accepted, and
# credentialed requests are allowed.
#
# NOTE(review): FastAPI's CORS documentation says allow_origins,
# allow_methods and allow_headers must not be ["*"] when
# allow_credentials=True (browsers reject a wildcard origin on credentialed
# requests). The /embedding endpoint in this file only takes an uploaded
# file, so allow_credentials=False or an explicit origin list is probably
# what is intended — confirm with the frontend before changing behavior.
app.add_middleware(
    CORSMiddleware,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_origins=["*"],
    allow_credentials=True,
)
|
|
|
|
|
|
|
|
# Hugging Face checkpoint shared by the preprocessor and the backbone.
_MODEL_ID = "facebook/dinov2-base"

# Both objects are created once at import time; the request handler reuses
# these module-level instances instead of reloading the weights per call.
processor = AutoImageProcessor.from_pretrained(_MODEL_ID)

# nn.Module.eval() switches to inference mode and returns the module itself,
# so it can be chained directly onto the load.
model = AutoModel.from_pretrained(_MODEL_ID).eval()
|
|
|
|
|
logger = logging.getLogger(__name__)


def _compute_embedding(image_bytes):
    """Decode *image_bytes* as an image and return its pooled embedding.

    This is synchronous, blocking work (image decoding plus a model forward
    pass). It must not be called directly from the event loop.

    Args:
        image_bytes: Raw bytes of an image file in any format PIL can open.

    Returns:
        A list of floats: the token embeddings from ``model`` averaged over
        dim 1 for the single image in the batch.

    Raises:
        Whatever PIL, the processor or the model raise on bad input
        (e.g. ``PIL.UnidentifiedImageError`` for non-image bytes).
    """
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # Assumes last_hidden_state is (batch, tokens, hidden) — mean over the
    # token axis, then take the only batch item. Indexing [0] rather than
    # squeeze() avoids accidentally collapsing any other size-1 axis.
    return outputs.last_hidden_state.mean(dim=1)[0].tolist()


@app.post("/embedding")
async def get_embedding(file: UploadFile = File(...)):
    """Return the embedding of an uploaded image.

    Args:
        file: The uploaded image (multipart form field ``file``).

    Returns:
        ``{"embedding": [...]}`` on success, or ``{"error": "<message>"}`` if
        reading, decoding or inference fails. Both are returned with the
        default status code, which is kept for compatibility with existing
        clients.
    """
    try:
        image_bytes = await file.read()
        # Decoding and the forward pass are CPU/GPU-bound and blocking.
        # Running them inline in this coroutine would freeze the event loop
        # (and every other in-flight request), so hand them to a worker thread.
        embedding = await run_in_threadpool(_compute_embedding, image_bytes)
    except Exception as e:
        # Top-level request boundary: keep the existing error payload, but log
        # the full traceback instead of silently discarding it.
        logger.exception("Failed to compute embedding for upload %r", file.filename)
        return {"error": str(e)}
    return {"embedding": embedding}
|
|
|