Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

EdgeCraft RAG UI bug fix #1189

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions EdgeCraftRAG/edgecraftrag/components/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def run(self, chat_request, retrieved_nodes, **kwargs):
repetition_penalty=chat_request.repetition_penalty,
)
self.llm().generate_kwargs = generate_kwargs
self.llm().max_new_tokens = chat_request.max_tokens
if chat_request.stream:

async def stream_generator():
Expand Down Expand Up @@ -99,8 +100,10 @@ def run_vllm(self, chat_request, retrieved_nodes, **kwargs):
max_tokens=chat_request.max_tokens,
model=model_name,
top_p=chat_request.top_p,
top_k=chat_request.top_k,
temperature=chat_request.temperature,
streaming=chat_request.stream,
repetition_penalty=chat_request.repetition_penalty,
)

if chat_request.stream:
Expand Down
2 changes: 1 addition & 1 deletion EdgeCraftRAG/ui/gradio/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ name: "default"

# Node parser
node_parser: "simple"
chunk_size: 192
chunk_size: 400
chunk_overlap: 48

# Indexer
Expand Down
2 changes: 1 addition & 1 deletion EdgeCraftRAG/ui/gradio/ecrag_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def create_update_pipeline(
weight=llm_weights,
),
),
retriever=api_schema.RetrieverIn(retriever_type=retriever, retriever_topk=vector_search_top_k),
retriever=api_schema.RetrieverIn(retriever_type=retriever, retrieve_topk=vector_search_top_k),
postprocessor=[
api_schema.PostProcessorIn(
processor_type=postprocessor[0],
Expand Down
71 changes: 25 additions & 46 deletions EdgeCraftRAG/ui/gradio/ecragui.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,11 @@ async def bot(
top_k,
repetition_penalty,
max_tokens,
hide_full_prompt,
docs,
chunk_size,
chunk_overlap,
vector_search_top_k,
vector_search_top_n,
run_rerank,
search_method,
score_threshold,
vector_rerank_top_n,
):
"""Callback function for running chatbot on submit button click.

Expand All @@ -108,8 +104,21 @@ async def bot(
repetition_penalty: parameter for penalizing tokens based on how frequently they occur in the text.
conversation_id: unique conversation identifier.
"""
if history[-1][0] == "" or len(history[-1][0]) == 0:
yield history[:-1]
return

stream_opt = True
new_req = {"messages": history[-1][0], "stream": stream_opt, "max_tokens": max_tokens}
new_req = {
"messages": history[-1][0],
"stream": stream_opt,
"max_tokens": max_tokens,
"top_n": vector_rerank_top_n,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"repetition_penalty": repetition_penalty,
}
server_addr = f"http://{MEGA_SERVICE_HOST_IP}:{MEGA_SERVICE_PORT}"

# Async for streaming response
Expand Down Expand Up @@ -362,7 +371,7 @@ def get_pipeline_df():
choices=avail_llm_inference_type, label="LLM Inference Type", value="local"
)

with gr.Accordion("LLM Configuration", open=True):
with gr.Accordion("LLM Configuration", open=True) as accordion:
u_llm_model_id = gr.Dropdown(
choices=avail_llms,
value=cfg.llm_model_id,
Expand Down Expand Up @@ -393,6 +402,12 @@ def get_pipeline_df():
# RAG Settings Events
# -------------------
# Event handlers
def update_visibility(selected_value):
    """Return an Accordion update that is hidden only for the "vllm" choice.

    Args:
        selected_value: The value delivered by the triggering component's
            change event; compared against the string "vllm".

    Returns:
        A ``gr.Accordion`` with ``visible=False`` when ``selected_value`` is
        "vllm", and ``visible=True`` for any other value.
    """
    is_vllm = selected_value == "vllm"
    return gr.Accordion(visible=not is_vllm)

def show_pipeline_detail(evt: gr.SelectData):
# get selected pipeline id
# Dataframe: {'headers': '', 'data': [[x00, x01], [x10, x11]]}
Expand Down Expand Up @@ -470,6 +485,8 @@ def create_update_pipeline(
return res, get_pipeline_df()

# Events
u_llm_infertype.change(update_visibility, inputs=u_llm_infertype, outputs=accordion)

u_pipelines.select(
show_pipeline_detail,
inputs=None,
Expand Down Expand Up @@ -735,39 +752,9 @@ def delete_file():
with gr.Row():
submit = gr.Button("Submit")
clear = gr.Button("Clear")
retriever_argument = gr.Accordion("Retriever Configuration", open=True)
retriever_argument = gr.Accordion("Retriever Configuration", open=False)
with retriever_argument:
with gr.Row():
with gr.Row():
do_rerank = gr.Checkbox(
value=True,
label="Rerank searching result",
interactive=True,
)
hide_context = gr.Checkbox(
value=True,
label="Hide searching result in prompt",
interactive=True,
)
with gr.Row():
search_method = gr.Dropdown(
["similarity_score_threshold", "similarity", "mmr"],
value=cfg.search_method,
label="Searching Method",
info="Method used to search vector store",
multiselect=False,
interactive=True,
)
with gr.Row():
score_threshold = gr.Slider(
0.01,
0.99,
value=cfg.score_threshold,
step=0.01,
label="Similarity Threshold",
info="Only working for 'similarity score threshold' method",
interactive=True,
)
with gr.Row():
vector_rerank_top_n = gr.Slider(
1,
Expand Down Expand Up @@ -811,15 +798,11 @@ def delete_file():
top_k,
repetition_penalty,
u_max_tokens,
hide_context,
docs,
u_chunk_size,
u_chunk_overlap,
u_vector_search_top_k,
vector_rerank_top_n,
do_rerank,
search_method,
score_threshold,
],
chatbot,
queue=True,
Expand All @@ -833,15 +816,11 @@ def delete_file():
top_k,
repetition_penalty,
u_max_tokens,
hide_context,
docs,
u_chunk_size,
u_chunk_overlap,
u_vector_search_top_k,
vector_rerank_top_n,
do_rerank,
search_method,
score_threshold,
],
chatbot,
queue=True,
Expand Down
Loading