import os
import time
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import List
from fastapi import HTTPException
from langchain_aws import BedrockLLM
from langchain_core.documents import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFDirectoryLoader
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from src.chatbot.prompts import PROMPT
from src.chatbot.config import MAX_THREADS, CHUNK_SIZE, CHUNK_OVERLAP

logger = logging.getLogger(__name__)

# PDF Document Processor
class PDFDocumentProcessor:
    def __init__(self, data_directory: str):
        self.data_directory = data_directory

    def load_and_chunk_documents(self) -> List[Document]:
        start_time = time.time()
        try:
            loader = PyPDFDirectoryLoader(self.data_directory)
            logger.info(f"Loading PDFs from {self.data_directory}...")
            documents = loader.load()

            text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)

            # Parallel chunking: partition the documents into batches so each
            # worker thread splits its own share of the pages.
            batch_size = max(1, len(documents) // MAX_THREADS)
            batches = [documents[i:i + batch_size] for i in range(0, len(documents), batch_size)]
            with ThreadPoolExecutor(max_workers=MAX_THREADS) as executor:
                chunked_documents = list(executor.map(text_splitter.split_documents, batches))

            logger.info(f"Document loading and chunking completed in {time.time() - start_time:.2f} seconds.")
            return [chunk for sublist in chunked_documents for chunk in sublist]
        except FileNotFoundError:
            logger.error(f"Data directory '{self.data_directory}' not found.")
            raise HTTPException(status_code=404, detail="Data directory not found")
        except Exception as e:
            logger.error(f"Error loading and chunking documents: {e}", exc_info=True)
            raise HTTPException(status_code=500, detail="Error processing documents")

# FAISS Manager
class FAISSManager:
    def __init__(self, index_path: str, embeddings):
        self.index_path = index_path
        self.embeddings = embeddings

    def create_and_save_vector_store(self, chunked_documents: List[Document]):
        try:
            # Ensure the directory for the index exists (index_path may be a bare name)
            index_dir = os.path.dirname(self.index_path)
            if index_dir and not os.path.exists(index_dir):
                os.makedirs(index_dir, exist_ok=True)
                logger.info(f"Created directory for FAISS index: {index_dir}")

            vectorstore_faiss = FAISS.from_documents(chunked_documents, self.embeddings)
            vectorstore_faiss.save_local(self.index_path)
            logger.info(f"FAISS index created and saved to {self.index_path}.")
        except Exception as e:
            logger.error(f"Error creating and saving FAISS vector store: {e}", exc_info=True)
            raise HTTPException(status_code=500, detail="Error creating FAISS vector store")

    def load_vector_store(self):
        # Check for the index outside the try block so the 404 below is not
        # re-wrapped as a 500 by the generic exception handler.
        if not os.path.exists(self.index_path):
            logger.error(f"FAISS index '{self.index_path}' not found. Create the index before loading it.")
            raise HTTPException(status_code=404, detail="FAISS index not found")

        try:
            logger.info(f"Loading FAISS index from {self.index_path}...")
            return FAISS.load_local(self.index_path, self.embeddings, allow_dangerous_deserialization=True)
        except FileNotFoundError:
            logger.error(f"FAISS index file '{self.index_path}' not found.")
            raise HTTPException(status_code=404, detail="FAISS index not found")
        except Exception as e:
            logger.error(f"Error loading FAISS vector store: {e}", exc_info=True)
            raise HTTPException(status_code=500, detail="Error loading FAISS vector store")

# LLM Service
class LLMService:
    def __init__(self, model_id: str, client):
        self.model_id = model_id
        self.client = client

    def initialize_llm(self):
        try:
            logger.info(f"Initializing LLM with model ID: {self.model_id}")
            return BedrockLLM(model_id=self.model_id, client=self.client)
        except Exception as e:
            logger.error(f"Error initializing LLM: {e}", exc_info=True)
            raise HTTPException(status_code=500, detail="Error initializing LLM")

    def generate_response(self, llm, vectorstore_faiss, query: str):
        try:
            start_time = time.time()
            logger.info(f"Generating response for query: '{query}'")
            qa = RetrievalQA.from_chain_type(
                llm=llm,
                chain_type="stuff",
                retriever=vectorstore_faiss.as_retriever(search_type="similarity", search_kwargs={"k": 3}),
                return_source_documents=True,
                chain_type_kwargs={"prompt": PROMPT}
            )
            result = qa.invoke({"query": query})
            logger.info(f"Response generated in {time.time() - start_time:.2f} seconds.")
            return result['result']
        except Exception as e:
            logger.error(f"Error generating LLM response: {e}", exc_info=True)
            raise HTTPException(status_code=500, detail="Error generating response")
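

# Hypothetical usage sketch showing how the three classes above could be wired
# together. It assumes BedrockEmbeddings from langchain_aws and a boto3
# "bedrock-runtime" client; the data directory, index path, model ID, and query
# below are placeholder values, not settings defined elsewhere in this module.
if __name__ == "__main__":
    import boto3
    from langchain_aws import BedrockEmbeddings

    bedrock_client = boto3.client("bedrock-runtime")
    embeddings = BedrockEmbeddings(client=bedrock_client)

    # Load and chunk the PDFs, then build, persist, and reload the FAISS index.
    processor = PDFDocumentProcessor(data_directory="data/pdfs")
    chunks = processor.load_and_chunk_documents()

    faiss_manager = FAISSManager(index_path="indexes/faiss_index", embeddings=embeddings)
    faiss_manager.create_and_save_vector_store(chunks)
    vectorstore = faiss_manager.load_vector_store()

    # Initialize the Bedrock LLM and answer a query against the indexed documents.
    llm_service = LLMService(model_id="anthropic.claude-v2", client=bedrock_client)
    llm = llm_service.initialize_llm()
    print(llm_service.generate_response(llm, vectorstore, "What topics do these documents cover?"))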