import torch
import json
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
from typing import Optional
# --- 1. Load Model and Tokenizer ---
# Define the path to your trained model
MODEL_PATH = "./bertmodel"
# Define the path to your knowledge base
KNOWLEDGE_BASE_PATH = "womens_legal_rights_india_10000.json"
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
print("Loading classification model...")
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
# Set device to GPU (cuda:0) if available, otherwise CPU
device = 0 if torch.cuda.is_available() else -1
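# (In the transformers pipeline API, a non-negative integer selects that GPU
# index, while -1 selects the CPU.)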
print(f"Creating classification pipeline on device: {'cuda' if device == 0 else 'cpu'}...")
# Create the text-classification pipeline
classifier = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    device=device
)
print("Classification pipeline loaded successfully.")
# --- 2. Load Knowledge Base (Answers) ---
# We will now store a dictionary for each intent
intent_details_map = {}
print(f"Loading knowledge base from: {KNOWLEDGE_BASE_PATH}")
try:
    with open(KNOWLEDGE_BASE_PATH, 'r', encoding='utf-8') as f:
        knowledge_base_data = json.load(f)
    # Create a lookup map: intent -> {answer, source}
    # Each intent only needs to be added once, since its answer/source is the same
    for item in knowledge_base_data:
        if item['intent'] not in intent_details_map:
            intent_details_map[item['intent']] = {
                "answer": item.get('answer', 'No answer found.'),
                "source": item.get('source', 'No source found.')
            }
    print(f"Knowledge base loaded with {len(intent_details_map)} intent-to-detail mappings.")
except FileNotFoundError:
    print(f"CRITICAL ERROR: Knowledge base file not found at {KNOWLEDGE_BASE_PATH}")
except Exception as e:
    print(f"Error loading knowledge base: {e}")
# --- 3. Initialize FastAPI App ---
app = FastAPI(
    title="Legal Intent & Answer API",
    description="API to predict the intent of a legal question and provide a suitable answer.",
    version="1.1.0"
)
# --- 4. Define Request and Response Models ---
# This is what the user must send in their POST request
class Query(BaseModel):
    text: str
# This is what the API will return (now includes 'source')
class PredictionResponse(BaseModel):
    query: str
    predicted_intent: str
    confidence_score: float
    answer: str
    source: Optional[str] = None  # Source may be missing, so default to None
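# Example response body (all values illustrative):
#   {"query": "...", "predicted_intent": "DOWRY_HARASSMENT",
#    "confidence_score": 0.95, "answer": "...", "source": "..."}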
# --- 5. Define API Endpoints ---
@app.get("/")
def read_root():
"""
Root endpoint for health check.
"""
return {"status": "API is running",
"message": "Post to /predict with a 'text' field to get an intent and answer and source."}
@app.post("/predict", response_model=PredictionResponse)
def predict_intent(query: Query):
"""
Predicts the intent of a given legal question and provides a suitable answer.
"""
print(f"Received query: {query.text}")
# 1. Get prediction from model
model_result = classifier(query.text)[0]
predicted_intent = model_result['label']
confidence_score = model_result['score']
# 2. Retrieve details (answer and source) from our knowledge base
fallback_details = {
"answer": "Could not find a specific answer for this intent. Please rephrase your question or contact a legal professional for advice.",
"source": None
}
details = intent_details_map.get(predicted_intent, fallback_details)
# 3. Return the combined response
return {
"query": query.text,
"predicted_intent": predicted_intent,
"confidence_score": confidence_score,
"answer": details.get('answer'),
"source": details.get('source')
}
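# Example request against the prediction endpoint (the query text is
# hypothetical; the returned intent and answer depend on your trained model
# and knowledge base):
#   curl -X POST http://localhost:8000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "How do I file a domestic violence complaint?"}'
# --- 6. Local Entry Point ---
# A minimal way to serve the app locally, assuming uvicorn (the usual ASGI
# server for FastAPI) is installed; equivalent to running
# `uvicorn app:app --port 8000` from the shell.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)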