"""Legal intent and answer API.

Loads a fine-tuned sequence-classification model, wraps it in a Hugging Face
``text-classification`` pipeline, loads an intent -> {answer, source} lookup
table from a JSON knowledge base, and serves predictions through FastAPI.
"""
import json
from typing import Optional

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline

# --- 1. Load Model and Tokenizer ---

# Path to the trained model directory
MODEL_PATH = "./bertmodel"
# Path to the knowledge base file
KNOWLEDGE_BASE_PATH = "womens_legal_rights_india_10000.json"

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

print("Loading classification model...")
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)

# Pipeline device convention: 0 selects the first GPU, -1 selects the CPU.
device = 0 if torch.cuda.is_available() else -1
print(f"Creating classification pipeline on device: {'cuda' if device == 0 else 'cpu'}...")

# Create the text-classification pipeline
classifier = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    device=device,
)
print("Classification pipeline loaded successfully.")


# --- 2. Load Knowledge Base (Answers) ---

def _build_intent_details_map(records: list) -> tuple:
    """Build the intent -> {"answer", "source"} lookup from knowledge-base records.

    Only the first record seen for each intent is kept. Records that are not
    dicts, or whose ``intent`` is missing or not a string, are skipped instead
    of aborting the whole load.

    Returns:
        A ``(details_by_intent, skipped_count)`` tuple.
    """
    details_by_intent = {}
    skipped = 0
    for item in records:
        intent = item.get('intent') if isinstance(item, dict) else None
        if not isinstance(intent, str):
            skipped += 1
            continue
        if intent not in details_by_intent:
            details_by_intent[intent] = {
                "answer": item.get('answer', 'No answer found.'),
                "source": item.get('source', 'No source found.'),
            }
    return details_by_intent, skipped


def _load_intent_details(path: str) -> dict:
    """Read the knowledge base at *path* and return the intent lookup map.

    Loading is best-effort: any failure is reported on stdout and an empty
    map is returned, so the API still starts and serves the fallback answer.
    """
    print(f"Loading knowledge base from: {path}")
    try:
        with open(path, 'r', encoding='utf-8') as f:
            knowledge_base_data = json.load(f)
    except FileNotFoundError:
        print(f"CRITICAL ERROR: Knowledge base file not found at {path}")
        return {}
    except (OSError, ValueError) as e:
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        print(f"Error loading knowledge base: {e}")
        return {}

    if not isinstance(knowledge_base_data, list):
        print("Error loading knowledge base: expected a JSON array of records.")
        return {}

    details_by_intent, skipped = _build_intent_details_map(knowledge_base_data)
    if skipped:
        print(f"Skipped {skipped} malformed knowledge base record(s).")
    print(f"Knowledge base loaded with {len(details_by_intent)} intent-to-detail mappings.")
    return details_by_intent


# One dictionary of details (answer and source) per intent
intent_details_map = _load_intent_details(KNOWLEDGE_BASE_PATH)

# Returned when the predicted intent has no entry in the knowledge base.
_FALLBACK_DETAILS = {
    "answer": "Could not find a specific answer for this intent. Please rephrase your question or contact a legal professional for advice.",
    "source": None,
}


# --- 3. Initialize FastAPI App ---

app = FastAPI(
    title="Legal Intent & Answer API",
    description="API to predict the intent of a legal question and provide a suitable answer.",
    version="1.1.0",
)


# --- 4. Define Request and Response Models ---

class Query(BaseModel):
    """Request body: what the user must send in their POST request."""

    text: str


class PredictionResponse(BaseModel):
    """Response body: the prediction plus the matching answer and source."""

    query: str
    predicted_intent: str
    confidence_score: float
    answer: str
    # Explicit default so the field is genuinely optional on every pydantic
    # version; the fallback response carries no source.
    source: Optional[str] = None


# --- 5. Define API Endpoints ---

@app.get("/")
def read_root() -> dict:
    """
    Root endpoint for health check.
    """
    return {
        "status": "API is running",
        "message": "Post to /predict with a 'text' field to get an intent and answer and source.",
    }


@app.post("/predict", response_model=PredictionResponse)
def predict_intent(query: Query) -> dict:
    """
    Predicts the intent of a given legal question and provides a suitable answer.

    Raises:
        HTTPException: 422 when ``text`` is empty or whitespace-only.
    """
    print(f"Received query: {query.text}")

    # A blank question has no intent to classify; reject it up front rather
    # than return an arbitrary label.
    if not query.text.strip():
        raise HTTPException(status_code=422, detail="'text' must not be empty.")

    # 1. Get prediction from model. truncation=True clips over-long input to
    #    the model's maximum sequence length instead of raising.
    model_result = classifier(query.text, truncation=True)[0]
    predicted_intent = model_result['label']
    confidence_score = model_result['score']

    # 2. Retrieve details (answer and source) from our knowledge base
    details = intent_details_map.get(predicted_intent, _FALLBACK_DETAILS)

    # 3. Return the combined response
    return {
        "query": query.text,
        "predicted_intent": predicted_intent,
        "confidence_score": confidence_score,
        "answer": details.get('answer'),
        "source": details.get('source'),
    }