From eaa99259ad71db8a386a5384d8860a7781c7a01b Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Mon, 28 Oct 2024 20:55:58 -0700 Subject: [PATCH] fix some bugs --- chatbot_ui/app/run.py | 23 ++++++++++--------- crates/prompt_gateway/src/stream_context.rs | 6 ++++- demos/function_calling/api_server/app/main.py | 2 +- model_server/app/main.py | 4 ++-- 4 files changed, 20 insertions(+), 15 deletions(-) diff --git a/chatbot_ui/app/run.py b/chatbot_ui/app/run.py index b0d5acc6..ead9b0a8 100644 --- a/chatbot_ui/app/run.py +++ b/chatbot_ui/app/run.py @@ -87,7 +87,7 @@ def get_prompt_targets(): return None -def chat(query: Optional[str], conversation: Optional[List[Tuple[str, str]]], state): +def chat(query: Optional[str], messages: Optional[List[Tuple[str, str]]], state): if "history" not in state: state["history"] = [] @@ -119,19 +119,19 @@ def chat(query: Optional[str], conversation: Optional[List[Tuple[str, str]]], st if STREAM_RESPONSE: response = raw_response.parse() history.append({"role": "assistant", "content": "", "model": ""}) + messages.append((query, "")) # for gradio UI we don't want to show raw tool calls and messages from developer application # so we're filtering those out history_view = [h for h in history if h["role"] != "tool" and "content" in h] - messages = [ - (history_view[i]["content"], history_view[i + 1]["content"]) - for i in range(0, len(history_view) - 1, 2) - ] - for chunk in response: + print("chunk: " + str(chunk.to_dict())) if len(chunk.choices) > 0: if chunk.choices[0].delta.role: + print("role (hist): " + chunk.choices[0].delta.role) + print("role (resp): " + chunk.choices[0].delta.role) if history[-1]["role"] != chunk.choices[0].delta.role: + print("creating new history item: " + str(chunk.choices[0])) history.append( { "role": chunk.choices[0].delta.role, @@ -151,11 +151,12 @@ def chat(query: Optional[str], conversation: Optional[List[Tuple[str, str]]], st if chunk.choices[0].delta.tool_calls: history[-1]["tool_calls"] = chunk.choices[0].delta.tool_calls - if chunk.model and chunk.choices[0].delta.content: - messages[-1] = ( - messages[-1][0], - messages[-1][1] + chunk.choices[0].delta.content, - ) + if history[-1]["role"] != "tool": + if chunk.model and chunk.choices[0].delta.content: + messages[-1] = ( + messages[-1][0], + messages[-1][1] + chunk.choices[0].delta.content, + ) yield "", messages, state else: log.error(f"raw_response: {raw_response.text}") diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index 5d79d181..4bbd3fa6 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -900,7 +900,11 @@ impl StreamContext { // don't send tools message and api response to chat gpt for m in callout_context.request_body.messages.iter() { - if m.role == TOOL_ROLE || m.content.is_none() { + // don't send api response and tool calls to upstream LLMs + if m.role == TOOL_ROLE + || m.content.is_none() + || (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty()) + { continue; } messages.push(m.clone()); diff --git a/demos/function_calling/api_server/app/main.py b/demos/function_calling/api_server/app/main.py index e87a3a21..a69c75d1 100644 --- a/demos/function_calling/api_server/app/main.py +++ b/demos/function_calling/api_server/app/main.py @@ -71,7 +71,7 @@ class DefaultTargetRequest(BaseModel): @app.post("/default_target") async def default_target(req: DefaultTargetRequest, res: Response): - logger.info(f"Received arch_messages: {req.messages}") + logger.info(f"Received messages: {req.messages}") resp = { "choices": [ { diff --git a/model_server/app/main.py b/model_server/app/main.py index 93d6217b..a8d312d7 100644 --- a/model_server/app/main.py +++ b/model_server/app/main.py @@ -186,8 +186,8 @@ async def hallucination(req: HallucinationRequest, res: Response): start_time = time.perf_counter() classifier = zero_shot_model["pipeline"] - if "arch_messages" in req.parameters: - req.parameters.pop("arch_messages") + if "messages" in req.parameters: + req.parameters.pop("messages") candidate_labels = {f"{k} is {v}": k for k, v in req.parameters.items()}