Untangle the code and stream the OpenAI answer character by character
IvanRublev committed Jun 21, 2024
1 parent 7245822 commit 81a7c67
Showing 2 changed files with 90 additions and 70 deletions.
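
In short, the diff replaces the blocking _request_llm call with a _stream_rag generator that is rendered through st.write_stream, so the answer appears token by token instead of all at once. A minimal sketch of that pattern (assuming streamlit >= 1.31 and langchain-openai; the names below are illustrative, not the repository's own):

# Minimal sketch: stream a LangChain chat completion into the Streamlit UI.
# Assumes streamlit >= 1.31, langchain-openai, and OPENAI_API_KEY in the environment.
import streamlit as st
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI

chain = ChatOpenAI(model="gpt-3.5-turbo") | StrOutputParser()

question = st.chat_input("Your question")
if question:
    with st.chat_message("assistant"):
        # st.write_stream consumes the generator, renders tokens as they
        # arrive, and returns the concatenated answer when the stream ends.
        answer = st.write_stream(chain.stream(question))

In the commit itself, the generator handed to st.write_stream comes from _stream_rag, which pipes a prompt-building function into the model and parser so the RAG prompt is assembled only when the chain runs.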
159 changes: 89 additions & 70 deletions src/llm_pdf.py
@@ -19,52 +19,64 @@


def llm_pdf_app():
if "es" not in st.session_state:
st.session_state["es"] = Elasticsearch(Settings.elastic_url)

logger.info("UI loop")

icon = "💬"
st.set_page_config(page_title=Settings.app_description, page_icon=icon, layout="centered")
st.title(icon + " " + Settings.app_description, anchor="home")

if "input_disabled" not in st.session_state:
st.session_state["input_disabled"] = True

if "first_last_doc" not in st.session_state:
st.session_state["first_last_doc"] = None
if "pdf_hash" not in st.session_state or not st.session_state["pdf_hash"]:
st.subheader("👈 Please upload the pdf file first.")

# Display chat messages from the conversation history
if "messages" not in st.session_state:
st.session_state["messages"] = []

if "es" not in st.session_state:
st.session_state["es"] = Elasticsearch(Settings.elastic_url)

if "pdf_hash" not in st.session_state or not st.session_state["pdf_hash"]:
st.subheader("👈 Please upload the pdf file first.")

if "sample_question" not in st.session_state:
st.session_state["sample_question"] = None

for message in st.session_state.messages:
with st.chat_message(message["role"]):
if isinstance(message["content"], dict):
st.markdown(message["content"]["markdown"])

def set_sample_question(question):
st.session_state.sample_question = question

for question in message["content"]["sample_questions"]:
st.button(question, on_click=set_sample_question, args=[question])
else:
st.markdown(message["content"])

# Show the control for user input
if "input_disabled" not in st.session_state:
st.session_state["input_disabled"] = True

user_prompt = st.chat_input("Your question", disabled=st.session_state.input_disabled)

was_sample_question = False
# The sample question selected by the button overrides the user's input
if st.session_state.sample_question:
# we're in rerun, the messages history is gone, let's display progress
was_sample_question = True
user_prompt = st.session_state.sample_question
st.session_state.sample_question = None

if "first_last_doc" not in st.session_state:
st.session_state["first_last_doc"] = None

if user_prompt:
st.session_state.sample_question = None
elastic_docs = st.session_state.first_last_doc + _query_elastic(user_prompt)
prompt = _build_llm_prompt(user_prompt, st.session_state.messages, elastic_docs)
if was_sample_question:
with st.spinner("Thinking..."):
answer = _request_llm(prompt)
else:
answer = _request_llm(prompt)
with st.chat_message("user"):
st.markdown(user_prompt)

with st.chat_message("assistant"):
answer = st.write_stream(
_stream_rag(user_prompt, st.session_state.messages, st.session_state.first_last_doc)
)
st.session_state.messages.append({"role": "user", "content": user_prompt})
st.session_state.messages.append({"role": "assistant", "content": answer})

# Sidebar
# Sidebar to load pdf
st.sidebar.title("Your PDF")

_remove_chunks_from_elastic_on_exit()
@@ -91,7 +103,7 @@ def llm_pdf_app():
st.session_state.first_last_doc = None
st.session_state.messages = []
st.cache_data.clear()
st.session_state["input_disabled"] = True
st.session_state.input_disabled = True
st.rerun()

logger.info(f"A file was of size: {pdf_size} was written to: {pdf_path}")
@@ -102,46 +114,47 @@ def llm_pdf_app():
logger.info(f"The file was splitted into: {chunks_count} chunks")

def on_finish_fn():
# these messages will be displayed on the following rerun
st.session_state.messages.append(
{"role": "assistant", "content": "Thank you for the uploading! How I can help you?"}
)
st.session_state.messages.append(
{
"role": "assistant",
"content": {
"markdown": "Sample questions I can answer:",
"sample_questions": Settings.sample_questions,
},
}
)

docs = _docs_from_chunks(chunks, pdf_hash)

if docs:
if not st.session_state.first_last_doc:
if docs[0] != docs[-1]:
st.session_state.first_last_doc = [docs[0], docs[-1]]
else:
st.session_state.first_last_doc = [docs[0]]
# we split the docs: the first and last are kept in memory so they are always passed to the LLM;
# the rest are uploaded to Elastic for further retrieval
if docs[0] != docs[-1]:
st.session_state.first_last_doc = [docs[0], docs[-1]]
docs.pop(-1)
docs.pop(0)
else:
st.session_state.first_last_doc = [docs[0]]
docs.pop(0)
_upload_docs_to_elastic(docs, pdf_hash, on_finish_fn)

if st.session_state.input_disabled:
st.session_state["input_disabled"] = False
st.session_state.input_disabled = False
st.rerun()
else:
# no file uploaded, clean up associated attributes
if not st.session_state.input_disabled:
st.session_state.input_disabled = True
st.session_state.pdf_hash = None
st.session_state.first_last_doc = None
st.session_state.messages = []
st.cache_data.clear()
st.session_state["input_disabled"] = True
st.rerun()

# Display chat messages from history on app rerun
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])

if len(st.session_state.messages) == 1:
with st.chat_message("assistant"):
st.markdown("Sample questions:")
questions = ["What this document is about?", "How many pages are in the document?"]
for question in questions:
if st.button(question):
st.session_state.sample_question = question
st.rerun()


def _write_tmp_file(uploaded_file):
pdf_path = None
@@ -243,23 +256,6 @@ def _upload_docs_to_elastic(_docs, cache_data_pdf_hash, _on_finish_fn):
on_finish_fn()


def _query_elastic(question):
search_query = {
"size": 5,
"query": {"bool": {"must": {"multi_match": {"query": question, "fields": ["text"], "type": "best_fields"}}}},
}

es = st.session_state.es
index_name = _index_name()
response = es.search(index=index_name, body=search_query)

chunks = []
for hit in response["hits"]["hits"]:
chunks.append(hit["_source"])

return chunks


def _remove_chunks_from_elastic_on_exit():
# from https://discuss.streamlit.io/t/detecting-user-exit-browser-tab-closed-session-end/62066
thread = threading.Timer(interval=2, function=_remove_chunks_from_elastic_on_exit)
@@ -288,6 +284,37 @@ def _remove_chunks_from_elastic_on_exit():
return


def _stream_rag(prompt, messages_history, first_last_doc):
def prompt_fn(prompt):
docs = first_last_doc + _query_elastic(prompt)
llm_prompt = _build_llm_prompt(prompt, messages_history, docs)
logger.info(f"Prompt for LLM: {llm_prompt}")
return llm_prompt

model = ChatOpenAI(api_key=Settings.openai_api_key, model="gpt-3.5-turbo")
parser = StrOutputParser()
chain = prompt_fn | model | parser
logger.info(f"User prompt: {prompt}")
return chain.stream(prompt)


def _query_elastic(question):
search_query = {
"size": 5,
"query": {"bool": {"must": {"multi_match": {"query": question, "fields": ["text"], "type": "best_fields"}}}},
}

es = st.session_state.es
index_name = _index_name()
response = es.search(index=index_name, body=search_query)

chunks = []
for hit in response["hits"]["hits"]:
chunks.append(hit["_source"])

return chunks


def _build_llm_prompt(user_prompt, messages_history, docs):
template = PromptTemplate.from_template(Settings.llm_prompt_template)

@@ -314,14 +341,6 @@ def _build_llm_prompt(user_prompt, messages_history, docs):
return template.format(context=context, question=user_prompt)


def _request_llm(prompt):
model = ChatOpenAI(api_key=Settings.openai_api_key, model="gpt-3.5-turbo")
parser = StrOutputParser()
chain = model | parser
logger.info(prompt)
return chain.invoke(prompt)


def _index_name():
session_id = _session_id()
return f"chunks_{session_id}"
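
The body of _remove_chunks_from_elastic_on_exit is collapsed in the diff above; the linked Streamlit forum thread describes the usual pattern behind it: re-arm a short threading.Timer that carries the script-run context, and run cleanup once the session is no longer active. A rough sketch of that idea (it leans on Streamlit's internal runtime API, so treat the exact calls as an assumption rather than the repository's implementation):

# Rough sketch of session-end detection with a recurring timer.
# Assumption: Streamlit's internal runtime API; not the repository's actual code.
import threading

from streamlit.runtime import get_instance
from streamlit.runtime.scriptrunner import add_script_run_ctx, get_script_run_ctx


def watch_session(on_session_gone, interval=2):
    ctx = get_script_run_ctx()
    runtime = get_instance()
    if ctx is not None and runtime.is_active_session(ctx.session_id):
        # The session is still alive: re-arm the timer and attach the
        # script-run context so Streamlit calls work inside the new thread.
        timer = threading.Timer(interval, watch_session, args=[on_session_gone, interval])
        add_script_run_ctx(timer)
        timer.start()
    else:
        # The browser tab was closed or the session expired: clean up.
        on_session_gone()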
1 change: 1 addition & 0 deletions src/settings.py
@@ -53,6 +53,7 @@ class Settings:
[QUESTION]
{question}
"""
sample_questions = ["What is this document about?", "How many pages are in the document?"]

# From pyproject.toml

