|
|
import torch |
|
|
import json |
|
|
from fastapi import FastAPI |
|
|
from pydantic import BaseModel |
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline |
|
|
from typing import Optional |
|
|
|
|
|
|
|
|
|
|
|
# Directory holding the fine-tuned BERT sequence-classification model
# (tokenizer + weights, as produced by save_pretrained).
MODEL_PATH = "./bertmodel"

# JSON knowledge base mapping intent labels to answer/source details.
# NOTE(review): assumed to be a list of objects with 'intent'/'answer'/'source'
# keys — the loader below treats it that way.
KNOWLEDGE_BASE_PATH = "womens_legal_rights_india_10000.json"

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

print("Loading classification model...")
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)

# Hugging Face pipeline device convention: 0 = first CUDA device, -1 = CPU.
device = 0 if torch.cuda.is_available() else -1

print(f"Creating classification pipeline on device: {'cuda' if device == 0 else 'cpu'}...")

# Text-classification pipeline used by the /predict endpoint; returns a list
# of {'label': ..., 'score': ...} dicts per input string.
classifier = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    device=device
)

print("Classification pipeline loaded successfully.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Maps each intent label to its canonical {"answer": ..., "source": ...} details.
# Populated at startup; left empty if the knowledge base cannot be loaded
# (the /predict endpoint then always serves its fallback answer).
intent_details_map = {}

print(f"Loading knowledge base from: {KNOWLEDGE_BASE_PATH}")
try:
    with open(KNOWLEDGE_BASE_PATH, 'r', encoding='utf-8') as f:
        knowledge_base_data = json.load(f)

    for item in knowledge_base_data:
        intent = item.get('intent')
        if intent is None:
            # BUGFIX: item['intent'] previously raised KeyError on a malformed
            # entry, which the broad except below swallowed — silently aborting
            # the rest of the load and leaving a partial map. Skip bad rows
            # instead so one malformed entry cannot truncate the knowledge base.
            continue
        # First occurrence of an intent wins, matching the original behavior.
        intent_details_map.setdefault(intent, {
            "answer": item.get('answer', 'No answer found.'),
            "source": item.get('source', 'No source found.'),
        })

    print(f"Knowledge base loaded with {len(intent_details_map)} intent-to-detail mappings.")

except FileNotFoundError:
    print(f"CRITICAL ERROR: Knowledge base file not found at {KNOWLEDGE_BASE_PATH}")
except Exception as e:
    # Best-effort startup: log and continue with an empty map rather than crash.
    print(f"Error loading knowledge base: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# FastAPI application object exposing the intent-classification endpoints.
app = FastAPI(
    title="Legal Intent & Answer API",
    version="1.1.0",
    description="API to predict the intent of a legal question and provide a suitable answer.",
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Query(BaseModel):
    """Request body for /predict: the user's legal question."""
    # Raw question text fed to the classifier.
    text: str
|
|
|
|
|
|
|
|
class PredictionResponse(BaseModel):
    """Response body for /predict: predicted intent plus matched answer."""
    # Echo of the submitted question text.
    query: str
    # Label emitted by the classification pipeline.
    predicted_intent: str
    # Model confidence for the predicted label (0.0–1.0).
    confidence_score: float
    # Answer looked up from the knowledge base (or the fallback message).
    answer: str
    # Citation for the answer; None when the fallback answer is used.
    source: Optional[str]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/")
def read_root():
    """Health-check endpoint: confirms the service is up and points to /predict."""
    payload = {
        "status": "API is running",
        "message": "Post to /predict with a 'text' field to get an intent and answer and source.",
    }
    return payload
|
|
|
|
|
|
|
|
@app.post("/predict", response_model=PredictionResponse)
def predict_intent(query: Query):
    """
    Predict the intent of a legal question and return the matching answer.

    Runs the text through the classification pipeline, then looks the
    predicted label up in the knowledge base; unknown labels get a
    generic fallback answer with no source.
    """
    print(f"Received query: {query.text}")

    # The pipeline returns a list of result dicts; take the top prediction.
    top_prediction = classifier(query.text)[0]
    intent_label = top_prediction['label']
    score = top_prediction['score']

    # Served when the predicted intent has no knowledge-base entry.
    default_details = {
        "answer": "Could not find a specific answer for this intent. Please rephrase your question or contact a legal professional for advice.",
        "source": None
    }
    matched = intent_details_map.get(intent_label, default_details)

    return {
        "query": query.text,
        "predicted_intent": intent_label,
        "confidence_score": score,
        "answer": matched.get('answer'),
        "source": matched.get('source'),
    }
|
|
|
|
|
|