diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..4d76c44b3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +__pycache__ +.venv +secrets.toml \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 4beaed2ee..fc44e987d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -streamlit +streamlit-nightly openai llama-index nltk diff --git a/retrieval_handler.py b/retrieval_handler.py new file mode 100644 index 000000000..aa48ab34e --- /dev/null +++ b/retrieval_handler.py @@ -0,0 +1,67 @@ +from typing import Any, Dict, List, Optional + +from llama_index.callbacks.base import BaseCallbackHandler +from llama_index.callbacks.schema import CBEventType +import streamlit as st + +class StreamlitRetrievalHandler(BaseCallbackHandler): + """Callback handler for writing retrieval results to Streamlit.""" + + def __init__( + self, + container = None, + event_starts_to_ignore: Optional[List[CBEventType]] = None, + event_ends_to_ignore: Optional[List[CBEventType]] = None, + verbose: bool = False, + ) -> None: + self._container = container + + super().__init__( + event_starts_to_ignore=event_starts_to_ignore or [], + event_ends_to_ignore=event_ends_to_ignore or [], + ) + + def set_container(self, container): + self._container = container + + def start_trace(self, trace_id: Optional[str] = None) -> None: + return + + def end_trace( + self, + trace_id: Optional[str] = None, + trace_map: Optional[Dict[str, List[str]]] = None, + ) -> None: + return + + def on_event_start( + self, + event_type: CBEventType, + payload: Optional[Dict[str, Any]] = None, + event_id: str = "", + **kwargs: Any, + ) -> str: + if ( + event_type in (CBEventType.QUERY) + and event_type not in self.event_starts_to_ignore + and payload is not None + ): + self._results = st.status("Gathering context") + return event_id + + def on_event_end( + self, + event_type: CBEventType, + payload: Optional[Dict[str, Any]] = None, + event_id: str = "", + **kwargs: Any, + ) -> None: + if ( + event_type in (CBEventType.RETRIEVE) + and event_type not in self.event_ends_to_ignore + and payload is not None + ): + for idx, node in enumerate(payload["nodes"]): + self._results.write(f"**Node {idx}: Score: {node.score}**") + self._results.write(node.node.text) + self._results.update(state="complete") diff --git a/streamlit_app.py b/streamlit_app.py index 0dbc35ce9..502d33a56 100644 --- a/streamlit_app.py +++ b/streamlit_app.py @@ -1,8 +1,9 @@ -import streamlit as st -from llama_index import VectorStoreIndex, ServiceContext, Document -from llama_index.llms import OpenAI import openai -from llama_index import SimpleDirectoryReader +from llama_index import VectorStoreIndex, ServiceContext, SimpleDirectoryReader +from llama_index.callbacks import CallbackManager +from llama_index.llms import OpenAI +import streamlit as st +from retrieval_handler import StreamlitRetrievalHandler st.set_page_config(page_title="Chat with the Streamlit docs, powered by LlamaIndex", page_icon="🦙", layout="centered", initial_sidebar_state="auto", menu_items=None) openai.api_key = st.secrets.openai_key @@ -13,18 +14,21 @@ st.session_state.messages = [ {"role": "assistant", "content": "Ask me a question about Streamlit's open-source Python library!"} ] +st_cb = StreamlitRetrievalHandler() @st.cache_resource(show_spinner=False) def load_data(): with st.spinner(text="Loading and indexing the Streamlit docs – hang tight! This should take 1-2 minutes."): reader = SimpleDirectoryReader(input_dir="./data", recursive=True) docs = reader.load_data() - service_context = ServiceContext.from_defaults(llm=OpenAI(model="gpt-3.5-turbo", temperature=0.5, system_prompt="You are an expert on the Streamlit Python library and your job is to answer technical questions. Assume that all questions are related to the Streamlit Python library. Keep your answers technical and based on facts – do not hallucinate features.")) + service_context = ServiceContext.from_defaults( + llm=OpenAI(model="gpt-3.5-turbo", temperature=0.5, system_prompt="You are an expert on the Streamlit Python library and your job is to answer technical questions. Assume that all questions are related to the Streamlit Python library. Keep your answers technical and based on facts – do not hallucinate features."), + callback_manager=CallbackManager([st_cb]), + ) index = VectorStoreIndex.from_documents(docs, service_context=service_context) return index index = load_data() -# chat_engine = index.as_chat_engine(chat_mode="condense_question", verbose=True, system_prompt="You are an expert on the Streamlit Python library and your job is to answer technical questions. Assume that all questions are related to the Streamlit Python library. Keep your answers technical and based on facts – do not hallucinate features.") chat_engine = index.as_chat_engine(chat_mode="condense_question", verbose=True) if prompt := st.chat_input("Your question"): # Prompt for user input and save to chat history