diff --git a/EdgeCraftRAG/edgecraftrag/components/generator.py b/EdgeCraftRAG/edgecraftrag/components/generator.py
index 02c8cec2b..3d81205fa 100644
--- a/EdgeCraftRAG/edgecraftrag/components/generator.py
+++ b/EdgeCraftRAG/edgecraftrag/components/generator.py
@@ -66,6 +66,7 @@ def run(self, chat_request, retrieved_nodes, **kwargs):
             repetition_penalty=chat_request.repetition_penalty,
         )
         self.llm().generate_kwargs = generate_kwargs
+        self.llm().max_new_tokens = chat_request.max_tokens
         if chat_request.stream:
 
             async def stream_generator():
@@ -99,8 +100,10 @@ def run_vllm(self, chat_request, retrieved_nodes, **kwargs):
             max_tokens=chat_request.max_tokens,
             model=model_name,
             top_p=chat_request.top_p,
+            top_k=chat_request.top_k,
             temperature=chat_request.temperature,
             streaming=chat_request.stream,
+            repetition_penalty=chat_request.repetition_penalty,
         )
 
         if chat_request.stream:
diff --git a/EdgeCraftRAG/ui/gradio/default.yaml b/EdgeCraftRAG/ui/gradio/default.yaml
index ad3718f0c..86f81cfb3 100644
--- a/EdgeCraftRAG/ui/gradio/default.yaml
+++ b/EdgeCraftRAG/ui/gradio/default.yaml
@@ -14,7 +14,7 @@ name: "default"
 
 # Node parser
 node_parser: "simple"
-chunk_size: 192
+chunk_size: 400
 chunk_overlap: 48
 
 # Indexer
diff --git a/EdgeCraftRAG/ui/gradio/ecrag_client.py b/EdgeCraftRAG/ui/gradio/ecrag_client.py
index 7a58ff720..070b3e3f4 100644
--- a/EdgeCraftRAG/ui/gradio/ecrag_client.py
+++ b/EdgeCraftRAG/ui/gradio/ecrag_client.py
@@ -67,7 +67,7 @@ def create_update_pipeline(
                 weight=llm_weights,
             ),
         ),
-        retriever=api_schema.RetrieverIn(retriever_type=retriever, retriever_topk=vector_search_top_k),
+        retriever=api_schema.RetrieverIn(retriever_type=retriever, retrieve_topk=vector_search_top_k),
         postprocessor=[
             api_schema.PostProcessorIn(
                 processor_type=postprocessor[0],
diff --git a/EdgeCraftRAG/ui/gradio/ecragui.py b/EdgeCraftRAG/ui/gradio/ecragui.py
index 23a5286de..211c1aa2b 100644
--- a/EdgeCraftRAG/ui/gradio/ecragui.py
+++ b/EdgeCraftRAG/ui/gradio/ecragui.py
@@ -87,15 +87,11 @@ async def bot(
     top_k,
     repetition_penalty,
     max_tokens,
-    hide_full_prompt,
     docs,
     chunk_size,
     chunk_overlap,
     vector_search_top_k,
-    vector_search_top_n,
-    run_rerank,
-    search_method,
-    score_threshold,
+    vector_rerank_top_n,
 ):
     """Callback function for running chatbot on submit button click.
 
@@ -102,13 +98,26 @@ async def bot(
     Params:
       message: new message from user.
       history: conversation history.
       temperature: parameter for control the level of creativity in AI-generated text.
       top_p: parameter for control the range of tokens considered by the AI based on their cumulative probability.
       repetition_penalty: parameter for penalizing tokens based on how frequently they occur in the text.
       conversation_id: unique conversation identifier.
     """
+    if history[-1][0] == "" or len(history[-1][0]) == 0:
+        yield history[:-1]
+        return
+
     stream_opt = True
-    new_req = {"messages": history[-1][0], "stream": stream_opt, "max_tokens": max_tokens}
+    new_req = {
+        "messages": history[-1][0],
+        "stream": stream_opt,
+        "max_tokens": max_tokens,
+        "top_n": vector_rerank_top_n,
+        "temperature": temperature,
+        "top_p": top_p,
+        "top_k": top_k,
+        "repetition_penalty": repetition_penalty,
+    }
     server_addr = f"http://{MEGA_SERVICE_HOST_IP}:{MEGA_SERVICE_PORT}"
 
     # Async for streaming response
@@ -362,7 +371,7 @@ def get_pipeline_df():
                             choices=avail_llm_inference_type, label="LLM Inference Type", value="local"
                         )
 
-                        with gr.Accordion("LLM Configuration", open=True):
+                        with gr.Accordion("LLM Configuration", open=True) as accordion:
                             u_llm_model_id = gr.Dropdown(
                                 choices=avail_llms,
                                 value=cfg.llm_model_id,
@@ -393,6 +402,12 @@ def get_pipeline_df():
         # RAG Settings Events
         # -------------------
         # Event handlers
+        def update_visibility(selected_value):  # Accept the event argument, even if not used
+            if selected_value == "vllm":
+                return gr.Accordion(visible=False)
+            else:
+                return gr.Accordion(visible=True)
+
         def show_pipeline_detail(evt: gr.SelectData):
             # get selected pipeline id
             # Dataframe: {'headers': '', 'data': [[x00, x01], [x10, x11]}
@@ -470,6 +485,8 @@ def create_update_pipeline(
             return res, get_pipeline_df()
 
         # Events
+        u_llm_infertype.change(update_visibility, inputs=u_llm_infertype, outputs=accordion)
+
         u_pipelines.select(
             show_pipeline_detail,
             inputs=None,
@@ -735,39 +752,9 @@ def delete_file():
                 with gr.Row():
                     submit = gr.Button("Submit")
                     clear = gr.Button("Clear")
-                retriever_argument = gr.Accordion("Retriever Configuration", open=True)
+                retriever_argument = gr.Accordion("Retriever Configuration", open=False)
                 with retriever_argument:
                     with gr.Row():
-                        with gr.Row():
-                            do_rerank = gr.Checkbox(
-                                value=True,
-                                label="Rerank searching result",
-                                interactive=True,
-                            )
-                            hide_context = gr.Checkbox(
-                                value=True,
-                                label="Hide searching result in prompt",
-                                interactive=True,
-                            )
-                        with gr.Row():
-                            search_method = gr.Dropdown(
-                                ["similarity_score_threshold", "similarity", "mmr"],
-                                value=cfg.search_method,
-                                label="Searching Method",
-                                info="Method used to search vector store",
-                                multiselect=False,
-                                interactive=True,
-                            )
-                        with gr.Row():
-                            score_threshold = gr.Slider(
-                                0.01,
-                                0.99,
-                                value=cfg.score_threshold,
-                                step=0.01,
-                                label="Similarity Threshold",
-                                info="Only working for 'similarity score threshold' method",
-                                interactive=True,
-                            )
                         with gr.Row():
                             vector_rerank_top_n = gr.Slider(
                                 1,
@@ -811,15 +798,11 @@ def delete_file():
                 top_k,
                 repetition_penalty,
                 u_max_tokens,
-                hide_context,
                 docs,
                 u_chunk_size,
                 u_chunk_overlap,
                 u_vector_search_top_k,
                 vector_rerank_top_n,
-                do_rerank,
-                search_method,
-                score_threshold,
             ],
             chatbot,
             queue=True,
@@ -833,15 +816,11 @@ def delete_file():
                 top_k,
                 repetition_penalty,
                 u_max_tokens,
-                hide_context,
                 docs,
                 u_chunk_size,
                 u_chunk_overlap,
                 u_vector_search_top_k,
                 vector_rerank_top_n,
-                do_rerank,
-                search_method,
-                score_threshold,
             ],
             chatbot,
             queue=True,