1- import os
2- import json
3- import time
41import boto3
52import logging
6- from typing import List , Optional
7- from dotenv import load_dotenv
8- from concurrent .futures import ThreadPoolExecutor
9- from fastapi import FastAPI , HTTPException , Depends , Request
3+ from typing import Optional
4+ from fastapi import FastAPI , HTTPException , Request
105from fastapi .middleware .cors import CORSMiddleware
116from pydantic import BaseModel , Field , ValidationError
12- from langchain_aws import BedrockEmbeddings , BedrockLLM
13- from langchain .text_splitter import RecursiveCharacterTextSplitter
14- from langchain_community .document_loaders import PyPDFDirectoryLoader
15- from langchain_community .vectorstores import FAISS
16- from langchain .prompts import PromptTemplate
17- from langchain .chains import RetrievalQA
7+ from langchain_aws import BedrockEmbeddings
188from fastapi .responses import JSONResponse
# Load environment variables from .env file
load_dotenv()

# Load configuration from config.json
try:
    with open("src/config.json", encoding="utf-8") as config_file:
        config = json.load(config_file)
except FileNotFoundError:
    raise RuntimeError("Configuration file 'src/config.json' not found.") from None

# Configuration variables from config.json
DATA_DIRECTORY = config.get("DATA_DIRECTORY")
FAISS_INDEX_PATH = config.get("FAISS_INDEX_PATH")
TITAN_MODEL_ID = config.get("TITAN_MODEL_ID")
LLAMA_MODEL_ID = config.get("LLAMA_MODEL_ID")
CHUNK_SIZE = config.get("CHUNK_SIZE")
CHUNK_OVERLAP = config.get("CHUNK_OVERLAP")
MAX_THREADS = config.get("MAX_THREADS")
LOG_LEVEL = config.get("LOG_LEVEL", "INFO").upper()

# Validate configuration variables: fail fast at import time on any missing key.
# NOTE: the previous loop variable was named `config`, shadowing the config
# dict above; the mapping form also lets the error name the missing keys.
_required_configs = {
    "DATA_DIRECTORY": DATA_DIRECTORY,
    "FAISS_INDEX_PATH": FAISS_INDEX_PATH,
    "TITAN_MODEL_ID": TITAN_MODEL_ID,
    "LLAMA_MODEL_ID": LLAMA_MODEL_ID,
    "CHUNK_SIZE": CHUNK_SIZE,
    "CHUNK_OVERLAP": CHUNK_OVERLAP,
    "MAX_THREADS": MAX_THREADS,
}
_missing = [name for name, value in _required_configs.items() if value is None]
if _missing:
    raise ValueError(f"Missing required configuration in config.json: {', '.join(_missing)}")
9+ from contextlib import asynccontextmanager
10+ from src .chatbot .config import DATA_DIRECTORY , FAISS_INDEX_PATH , TITAN_MODEL_ID , LLAMA_MODEL_ID , LOG_LEVEL
11+ from src .chatbot .services import FAISSManager , PDFDocumentProcessor , LLMService
4412
4513# Initialize FastAPI app
46- app = FastAPI ()
# Application lifespan: build the FAISS index once at startup. Indexing
# failures are logged but deliberately do not prevent the app from serving.
@asynccontextmanager
async def lifespan(app: FastAPI):
    try:
        logger.info("Lifespan event triggered. Automatically running the /create_index endpoint...")
        # Reuse the endpoint handler so startup and the manual route share one code path.
        await create_index()
    except Exception as exc:
        logger.error(f"Error during lifespan event: {exc}", exc_info=True)
    yield  # hand control to the running application; nothing to tear down

app = FastAPI(lifespan=lifespan)
4724
4825# Add CORS middleware
4926app .add_middleware (
6138)
6239logger = logging .getLogger (__name__ )
6340
# Few-shot prompt for Indian tax Q&A: two worked Human/Assistant examples,
# then the retrieved context and the user's question.
_PROMPT_TEMPLATE = """
You are a knowledgeable Indian tax advisor. Use the following pieces of context to provide a detailed answer to the question at the end.
Provide at least 250 words with detailed explanations, practical examples where applicable, and include an analysis at the end.

Examples:
1.
Human: What are the different types of income tax in India?
Assistant: In India, income tax is categorized into various heads based on the source of income. The primary types are:
- **Salaries**: Income earned from employment.
- **House Property**: Income from rental properties.
- **Business or Profession**: Profits from business activities or professional services.
- **Capital Gains**: Profits from the sale of capital assets such as stocks or real estate.
- **Other Sources**: Includes interest income, dividends, and other miscellaneous income.
Each category is subject to specific tax rates and exemptions under the Income Tax Act, 1961.

Analysis: Understanding the different types of income tax is crucial for individuals and businesses to comply with tax regulations and optimize their tax liabilities.

2.
Human: Can you explain the Goods and Services Tax (GST) framework in India?
Assistant: The Goods and Services Tax (GST) is a comprehensive indirect tax levied on the supply of goods and services in India, implemented on July 1, 2017. It subsumes various indirect taxes such as Value Added Tax (VAT), Central Excise Duty, and Service Tax. The GST structure consists of three components:
- **CGST**: Central Goods and Services Tax, collected by the central government.
- **SGST**: State Goods and Services Tax, collected by the state government.
- **IGST**: Integrated Goods and Services Tax, applied to inter-state transactions.
The GST aims to simplify the tax structure, enhance compliance, and eliminate the cascading effect of multiple taxes.

Analysis: The GST framework represents a significant reform in India's tax system, promoting a unified market and fostering ease of doing business.

<context>
{context}
</context>

Question: {question}

Assistant:
"""

# Keep the public name `prompt_template` for any external readers of it.
prompt_template = _PROMPT_TEMPLATE

PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
102-
# Request payload for the question-answering endpoint. The AWS credential
# fields are optional; when omitted, the server falls back to credentials
# from environment variables / the AWS CLI configuration.
class QuestionRequest(BaseModel):
    question: str = Field(..., json_schema_extra={"example": "What is the capital of France?"})
    aws_access_key_id: Optional[str] = Field(None, json_schema_extra={"example": "your_access_key_id"})
    aws_secret_access_key: Optional[str] = Field(None, json_schema_extra={"example": "your_secret_access_key"})
    aws_default_region: Optional[str] = Field(None, json_schema_extra={"example": "your_region"})
10947
# PDF Document Processor
class PDFDocumentProcessor:
    """Loads every PDF in a directory and splits the pages into overlapping
    text chunks suitable for embedding."""

    def __init__(self, data_directory: str):
        # Directory scanned (recursively, by PyPDFDirectoryLoader) for PDFs.
        self.data_directory = data_directory

    def load_and_chunk_documents(self) -> List[str]:
        """Load all PDFs under ``data_directory`` and return a flat list of chunks.

        Raises:
            HTTPException: 404 if the data directory is missing,
                500 on any other loading/splitting failure.
        """
        start_time = time.time()
        try:
            loader = PyPDFDirectoryLoader(self.data_directory)
            logger.info(f"Loading PDFs from {self.data_directory}...")
            documents = loader.load()

            text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)

            # BUG FIX: the previous version mapped the splitter over the
            # single-element list [documents], so the executor ran exactly one
            # task and the "parallel chunking" was a no-op. Submitting one
            # split task per document gives real fan-out across MAX_THREADS,
            # and executor.map preserves input order.
            with ThreadPoolExecutor(max_workers=MAX_THREADS) as executor:
                per_document_chunks = executor.map(
                    text_splitter.split_documents, ([doc] for doc in documents)
                )
                chunks = [chunk for doc_chunks in per_document_chunks for chunk in doc_chunks]

            logger.info(f"Document loading and chunking completed in {time.time() - start_time:.2f} seconds.")
            return chunks
        except FileNotFoundError:
            logger.error(f"Data directory '{self.data_directory}' not found.")
            raise HTTPException(status_code=404, detail="Data directory not found")
        except HTTPException:
            # Propagate deliberate HTTP errors unchanged instead of remapping to 500.
            raise
        except Exception as e:
            logger.error(f"Error loading and chunking documents: {e}", exc_info=True)
            raise HTTPException(status_code=500, detail="Error processing documents")
136-
# FAISS Manager
class FAISSManager:
    """Creates, persists, and reloads a FAISS vector store at a fixed path."""

    def __init__(self, index_path: str, embeddings):
        # Directory where the FAISS index files are stored.
        self.index_path = index_path
        # Embedding model used both to build and to reload the index.
        self.embeddings = embeddings
        self._ensure_index_directory_exists()

    def _ensure_index_directory_exists(self):
        """Ensure the directory for the FAISS index exists."""
        # BUG FIX: the former exists()-then-makedirs() pair raced with
        # concurrent workers creating the same directory (TOCTOU);
        # exist_ok=True makes the creation atomic and idempotent.
        os.makedirs(self.index_path, exist_ok=True)
        logger.info(f"FAISS index directory ready at {self.index_path}.")

    def create_and_save_vector_store(self, chunked_documents: List[str]):
        """Embed ``chunked_documents`` and persist the resulting index.

        Raises:
            HTTPException: 500 on any embedding/persistence failure.
        """
        try:
            vectorstore_faiss = FAISS.from_documents(chunked_documents, self.embeddings)
            vectorstore_faiss.save_local(self.index_path)
            logger.info(f"FAISS index created and saved to {self.index_path}.")
        except Exception as e:
            logger.error(f"Error creating and saving FAISS vector store: {e}", exc_info=True)
            raise HTTPException(status_code=500, detail="Error creating FAISS vector store")

    def load_vector_store(self):
        """Load the persisted index from disk.

        Raises:
            HTTPException: 404 if the index is missing, 500 otherwise.
        """
        try:
            logger.info(f"Loading FAISS index from {self.index_path}...")
            # allow_dangerous_deserialization: the index is produced locally by
            # this service, not untrusted input.
            return FAISS.load_local(self.index_path, self.embeddings, allow_dangerous_deserialization=True)
        except FileNotFoundError:
            logger.error(f"FAISS index file '{self.index_path}' not found.")
            raise HTTPException(status_code=404, detail="FAISS index not found")
        except Exception as e:
            logger.error(f"Error loading FAISS vector store: {e}", exc_info=True)
            raise HTTPException(status_code=500, detail="Error loading FAISS vector store")
171-
# LLM Service
class LLMService:
    """Thin wrapper around a Bedrock LLM for retrieval-augmented answering."""

    def __init__(self, model_id: str, client):
        # Bedrock model identifier and the pre-built bedrock-runtime client.
        self.model_id = model_id
        self.client = client

    def initialize_llm(self):
        """Construct the BedrockLLM instance; 500 HTTPException on failure."""
        try:
            logger.info(f"Initializing LLM with model ID: {self.model_id}")
            return BedrockLLM(model_id=self.model_id, client=self.client)
        except Exception as e:
            logger.error(f"Error initializing LLM: {e}", exc_info=True)
            raise HTTPException(status_code=500, detail="Error initializing LLM")

    def generate_response(self, llm, vectorstore_faiss, query: str):
        """Run a RetrievalQA chain over the vector store and return the answer text."""
        try:
            started = time.time()
            logger.info(f"Generating response for query: '{query}'")
            retriever = vectorstore_faiss.as_retriever(
                search_type="similarity", search_kwargs={"k": 3}
            )
            chain = RetrievalQA.from_chain_type(
                llm=llm,
                chain_type="stuff",
                retriever=retriever,
                return_source_documents=True,
                chain_type_kwargs={"prompt": PROMPT},
            )
            answer = chain.invoke({"query": query})["result"]
            logger.info(f"Response generated in {time.time() - started:.2f} seconds.")
            return answer
        except Exception as e:
            logger.error(f"Error generating LLM response: {e}", exc_info=True)
            raise HTTPException(status_code=500, detail="Error generating response")
203-
20448# Middleware to log requests and responses
20549@app .middleware ("http" )
20650async def log_requests (request : Request , call_next ):
@@ -237,62 +81,53 @@ async def create_index():
23781 processor = PDFDocumentProcessor (data_directory = DATA_DIRECTORY )
23882 chunked_documents = processor .load_and_chunk_documents ()
23983
240- # Initialize AWS client
241- bedrock_client = boto3 .client (
242- service_name = "bedrock-runtime" ,
243- aws_access_key_id = os .getenv ("AWS_ACCESS_KEY_ID" ),
244- aws_secret_access_key = os .getenv ("AWS_SECRET_ACCESS_KEY" ),
245- region_name = os .getenv ("AWS_DEFAULT_REGION" ),
246- )
247-
248- # Create and save FAISS index
249- faiss_manager = FAISSManager (FAISS_INDEX_PATH , BedrockEmbeddings (model_id = TITAN_MODEL_ID , client = bedrock_client ))
84+ # Load embeddings and create FAISS index
85+ embeddings = BedrockEmbeddings (model_id = TITAN_MODEL_ID )
86+ faiss_manager = FAISSManager (index_path = FAISS_INDEX_PATH , embeddings = embeddings )
25087 faiss_manager .create_and_save_vector_store (chunked_documents )
25188
252- return {"detail" : "FAISS index created successfully" }
89+ return {"message" : "FAISS index created successfully." }
90+ except HTTPException as http_exc :
91+ raise http_exc
25392 except Exception as e :
254- logger .error (f"Error in /create_index endpoint : { e } " , exc_info = True )
93+ logger .error (f"Error creating FAISS index : { e } " , exc_info = True )
25594 raise HTTPException (status_code = 500 , detail = "Error creating FAISS index" )
25695
# Question answering endpoint
# NOTE(review): this span held two interleaved versions of the endpoint
# (old /ask and new /answer) — merge/diff residue. Resolved to the newer
# /answer implementation.
@app.post("/answer")
async def answer_question(request: QuestionRequest):
    """Answer a question using the persisted FAISS index and a Bedrock LLM.

    AWS credentials may be supplied in the request body; otherwise the
    default boto3 credential chain (environment / AWS CLI config) is used.

    Raises:
        HTTPException: 404 if the FAISS index is missing, 500 on any other
            failure (propagated from the services or raised here).
    """
    try:
        logger.info(f"Received question: {request.question}")

        # Build the Bedrock runtime client from request-supplied credentials
        # when all three fields are present; otherwise fall back to the
        # default boto3 credential chain.
        if request.aws_access_key_id and request.aws_secret_access_key and request.aws_default_region:
            logger.info("AWS credentials provided in the request.")
            client = boto3.Session(
                aws_access_key_id=request.aws_access_key_id,
                aws_secret_access_key=request.aws_secret_access_key,
                region_name=request.aws_default_region,
            ).client("bedrock-runtime")
        else:
            client = boto3.client("bedrock-runtime")
            logger.info("Using AWS credentials from environment variables or AWS CLI configuration.")

        # Load FAISS index.
        # BUG FIX: the embeddings were built without the client constructed
        # above, so request-supplied credentials were silently ignored for
        # embedding calls; pass the same client to keep credentials consistent.
        embeddings = BedrockEmbeddings(model_id=TITAN_MODEL_ID, client=client)
        faiss_manager = FAISSManager(index_path=FAISS_INDEX_PATH, embeddings=embeddings)
        vectorstore_faiss = faiss_manager.load_vector_store()

        # Initialize LLM
        llm_service = LLMService(model_id=LLAMA_MODEL_ID, client=client)
        llm = llm_service.initialize_llm()

        # Generate response
        response = llm_service.generate_response(llm=llm, vectorstore_faiss=vectorstore_faiss, query=request.question)
        return {"answer": response}

    except HTTPException as http_exc:
        # Preserve deliberate HTTP errors (404/500 from the services) as-is.
        raise http_exc
    except Exception as e:
        logger.error(f"Error processing question: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Error processing question")
0 commit comments