From 6eceabf43eeb5c455674a1ad140453ea23127938 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Thu, 24 Oct 2024 17:45:04 -0700 Subject: [PATCH] fix more --- chatbot_ui/.vscode/launch.json | 5 +- chatbot_ui/app/arch_util.py | 20 ++ chatbot_ui/app/run.py | 83 +++++--- crates/common/src/common_types.rs | 3 + crates/common/src/consts.rs | 3 +- crates/prompt_gateway/src/hallucination.rs | 16 +- crates/prompt_gateway/src/http_context.rs | 204 ++++++++++---------- crates/prompt_gateway/src/stream_context.rs | 15 +- demos/llm_routing/arch_config.yaml | 5 - demos/llm_routing/docker-compose.yaml | 2 + 10 files changed, 204 insertions(+), 152 deletions(-) create mode 100644 chatbot_ui/app/arch_util.py diff --git a/chatbot_ui/.vscode/launch.json b/chatbot_ui/.vscode/launch.json index 6f81b218..947754a5 100644 --- a/chatbot_ui/.vscode/launch.json +++ b/chatbot_ui/.vscode/launch.json @@ -14,7 +14,9 @@ "console": "integratedTerminal", "env": { "LLM": "1", - "CHAT_COMPLETION_ENDPOINT": "http://localhost:10000/v1" + "CHAT_COMPLETION_ENDPOINT": "http://localhost:10000/v1", + "STREAMING": "True", + "ARCH_CONFIG": "../../demos/function_calling/arch_config.yaml" } }, { @@ -30,6 +32,7 @@ } }, { + "python": "${workspaceFolder}/venv/bin/python", "name": "chatbot-ui (llm) streaming", "cwd": "${workspaceFolder}/app", "type": "debugpy", diff --git a/chatbot_ui/app/arch_util.py b/chatbot_ui/app/arch_util.py new file mode 100644 index 00000000..567640e5 --- /dev/null +++ b/chatbot_ui/app/arch_util.py @@ -0,0 +1,20 @@ +import json + + +ARCH_STATE_HEADER = "x-arch-state" + + +def get_arch_messages(response_json): + arch_messages = [] + if response_json and "metadata" in response_json: + # load arch_state from metadata + arch_state_str = response_json.get("metadata", {}).get(ARCH_STATE_HEADER, "{}") + # parse arch_state into json object + arch_state = json.loads(arch_state_str) + # load messages from arch_state + arch_messages_str = arch_state.get("messages", "[]") + # parse messages into json object + arch_messages = json.loads(arch_messages_str) + # append messages from arch gateway to history + return arch_messages + return [] diff --git a/chatbot_ui/app/run.py b/chatbot_ui/app/run.py index 05a6a6db..d3c9dbd3 100644 --- a/chatbot_ui/app/run.py +++ b/chatbot_ui/app/run.py @@ -2,6 +2,7 @@ import json import os import logging import yaml +from arch_util import get_arch_messages import gradio as gr from typing import List, Optional, Tuple @@ -10,6 +11,8 @@ from dotenv import load_dotenv load_dotenv() +STREAM_RESPONSE = bool(os.getenv("STREAM_RESPOSE", True)) + logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", @@ -20,7 +23,6 @@ log = logging.getLogger(__name__) CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT") log.info(f"CHAT_COMPLETION_ENDPOINT: {CHAT_COMPLETION_ENDPOINT}") -ARCH_STATE_HEADER = "x-arch-state" CSS_STYLE = """ .json-container { @@ -69,7 +71,7 @@ def convert_prompt_target_to_openai_format(target): def get_prompt_targets(): try: - with open("arch_config.yaml", "r") as file: + with open(os.getenv("ARCH_CONFIG", "arch_config.yaml"), "r") as file: config = yaml.safe_load(file) available_tools = [] @@ -105,48 +107,65 @@ def chat(query: Optional[str], conversation: Optional[List[Tuple[str, str]]], st temperature=1.0, # metadata=metadata, extra_headers=custom_headers, + stream=STREAM_RESPONSE, ) except Exception as e: log.info(e) # remove last user message in case of exception history.pop() - log.info("Error calling gateway API: {}".format(e.message)) - raise gr.Error("Error calling gateway API: {}".format(e.message)) + log.info("Error calling gateway API: {}".format(e)) + raise gr.Error("Error calling gateway API: {}".format(e)) - log.error(f"raw_response: {raw_response.text}") - response = raw_response.parse() + if STREAM_RESPONSE: + response = raw_response.parse() + history.append({"role": "assistant", "content": "", "model": ""}) + # for gradio UI we don't want to show raw tool calls and messages from developer application + # so we're filtering those out + history_view = [h for h in history if h["role"] != "tool" and "content" in h] - # extract arch_state from metadata and store it in gradio session state - # this state must be passed back to the gateway in the next request - response_json = json.loads(raw_response.text) - log.info(response_json) - if response_json and "metadata" in response_json: - # load arch_state from metadata - arch_state_str = response_json.get("metadata", {}).get(ARCH_STATE_HEADER, "{}") - # parse arch_state into json object - arch_state = json.loads(arch_state_str) - # load messages from arch_state - arch_messages_str = arch_state.get("messages", "[]") - # parse messages into json object - arch_messages = json.loads(arch_messages_str) - # append messages from arch gateway to history - for message in arch_messages: - history.append(message) + messages = [ + (history_view[i]["content"], history_view[i + 1]["content"]) + for i in range(0, len(history_view) - 1, 2) + ] - content = response.choices[0].message.content + for chunk in response: + if len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None: + history[-1]["model"] = chunk.model + history[-1]["content"] = chunk.choices[0].delta.content + messages[-1] = ( + messages[-1][0], + messages[-1][1] + chunk.choices[0].delta.content, + ) + yield "", messages, state + else: + log.error(f"raw_response: {raw_response.text}") + response = raw_response.parse() - history.append({"role": "assistant", "content": content, "model": response.model}) + # extract arch_state from metadata and store it in gradio session state + # this state must be passed back to the gateway in the next request + response_json = json.loads(raw_response.text) + log.info(response_json) - # for gradio UI we don't want to show raw tool calls and messages from developer application - # so we're filtering those out - history_view = [h for h in history if h["role"] != "tool" and "content" in h] + arch_messages = get_arch_messages(response_json) + for arch_message in arch_messages: + history.append(arch_message) - messages = [ - (history_view[i]["content"], history_view[i + 1]["content"]) - for i in range(0, len(history_view) - 1, 2) - ] + content = response.choices[0].message.content - return "", messages, state + history.append( + {"role": "assistant", "content": content, "model": response.model} + ) + + # for gradio UI we don't want to show raw tool calls and messages from developer application + # so we're filtering those out + history_view = [h for h in history if h["role"] != "tool" and "content" in h] + + messages = [ + (history_view[i]["content"], history_view[i + 1]["content"]) + for i in range(0, len(history_view) - 1, 2) + ] + + yield "", messages, state def main(): diff --git a/crates/common/src/common_types.rs b/crates/common/src/common_types.rs index f9bf5921..57da1519 100644 --- a/crates/common/src/common_types.rs +++ b/crates/common/src/common_types.rs @@ -269,6 +269,9 @@ pub mod open_ai { .events .iter() .map(|response_chunk| { + if response_chunk.choices.is_empty() { + return "".to_string(); + } response_chunk.choices[0] .delta .content diff --git a/crates/common/src/consts.rs b/crates/common/src/consts.rs index 81df31f8..f8a8b847 100644 --- a/crates/common/src/consts.rs +++ b/crates/common/src/consts.rs @@ -25,4 +25,5 @@ pub const ARCH_INTERNAL_CLUSTER_NAME: &str = "arch_internal"; pub const ARCH_UPSTREAM_HOST_HEADER: &str = "x-arch-upstream"; pub const ARCH_LLM_UPSTREAM_LISTENER: &str = "arch_llm_listener"; pub const ARCH_MODEL_PREFIX: &str = "Arch"; -pub const HALLUCINATION_TEMPLATE: &str = "It seems I’m missing some information. Could you provide the following details "; +pub const HALLUCINATION_TEMPLATE: &str = + "It seems I’m missing some information. Could you provide the following details "; diff --git a/crates/prompt_gateway/src/hallucination.rs b/crates/prompt_gateway/src/hallucination.rs index b4621ec4..130f8723 100644 --- a/crates/prompt_gateway/src/hallucination.rs +++ b/crates/prompt_gateway/src/hallucination.rs @@ -1,9 +1,9 @@ use common::{ common_types::open_ai::Message, - consts::{ARCH_MODEL_PREFIX, USER_ROLE, HALLUCINATION_TEMPLATE}, + consts::{ARCH_MODEL_PREFIX, HALLUCINATION_TEMPLATE, USER_ROLE}, }; -pub fn extract_messages_for_hallucination(messages: &Vec) -> Vec { +pub fn extract_messages_for_hallucination(messages: &[Message]) -> Vec { let mut arch_assistant = false; let mut user_messages = Vec::new(); if messages.len() >= 2 { @@ -18,11 +18,11 @@ pub fn extract_messages_for_hallucination(messages: &Vec) -> Vec) -> Vec deserialized, - Err(e) => { - self.send_server_error( - ServerError::Deserialization(e), - Some(StatusCode::BAD_REQUEST), - ); - return Action::Pause; - } - }; + let deserialized_body: ChatCompletionsRequest = match serde_json::from_slice(&body_bytes) { + Ok(deserialized) => deserialized, + Err(e) => { + self.send_server_error( + ServerError::Deserialization(e), + Some(StatusCode::BAD_REQUEST), + ); + return Action::Pause; + } + }; self.arch_state = match deserialized_body.metadata { Some(ref metadata) => { @@ -113,11 +112,6 @@ impl HttpContext for StreamContext { }; self.streaming_response = deserialized_body.stream; - if deserialized_body.stream && deserialized_body.stream_options.is_none() { - deserialized_body.stream_options = Some(StreamOptions { - include_usage: true, - }); - } let last_user_prompt = match deserialized_body .messages @@ -238,105 +232,119 @@ impl HttpContext for StreamContext { ); if !self.is_chat_completions_request { - if let Some(body_str) = self - .get_http_response_body(0, body_size) - .and_then(|bytes| String::from_utf8(bytes).ok()) - { - debug!("recv [S={}] body_str={}", self.context_id, body_str); - } + debug!("non-streaming request"); return Action::Continue; } - if !end_of_stream { - return Action::Pause; - } + let body = if self.streaming_response { + let streaming_chunk = match self.get_http_response_body(0, body_size) { + Some(chunk) => chunk, + None => { + warn!( + "response body empy, chunk_start: {}, chunk_size: {}", + 0, body_size + ); + return Action::Continue; + } + }; - let body = self - .get_http_response_body(0, body_size) - .expect("cant get response body"); + if streaming_chunk.len() != body_size { + warn!( + "chunk size mismatch: read: {} != requested: {}", + streaming_chunk.len(), + body_size + ); + } + + streaming_chunk + } else { + debug!("non streaming response bytes read: 0:{}", body_size); + match self.get_http_response_body(0, body_size) { + Some(body) => body, + None => { + warn!("non streaming response body empty"); + return Action::Continue; + } + } + }; + + let body_utf8 = match String::from_utf8(body) { + Ok(body_utf8) => body_utf8, + Err(e) => { + debug!("could not convert to utf8: {}", e); + return Action::Continue; + } + }; if self.streaming_response { trace!("streaming response"); - } else { - trace!("non streaming response"); - let chat_completions_response: ChatCompletionsResponse = - match serde_json::from_slice(&body) { - Ok(de) => de, + + let chat_completions_chunk_response_events = + match ChatCompletionChunkResponseServerEvents::try_from(body_utf8.as_str()) { + Ok(response) => response, Err(e) => { - trace!( - "invalid response: {}, {}", - String::from_utf8_lossy(&body), - e + debug!( + "invalid streaming response: body str: {}, {:?}", + body_utf8, e ); return Action::Continue; } }; + debug!( + "parsed events: {}", + chat_completions_chunk_response_events.to_string() + ); + } else if let Some(tool_calls) = self.tool_calls.as_ref() { + if !tool_calls.is_empty() { + if self.arch_state.is_none() { + self.arch_state = Some(Vec::new()); + } - if chat_completions_response.usage.is_some() { - self.response_tokens += chat_completions_response - .usage - .as_ref() - .unwrap() - .completion_tokens; - } - - if let Some(tool_calls) = self.tool_calls.as_ref() { - if !tool_calls.is_empty() { - if self.arch_state.is_none() { - self.arch_state = Some(Vec::new()); + let mut data = serde_json::from_str(&body_utf8).unwrap(); + // use serde::Value to manipulate the json object and ensure that we don't lose any data + if let Value::Object(ref mut map) = data { + // serialize arch state and add to metadata + let metadata = map + .entry("metadata") + .or_insert(Value::Object(serde_json::Map::new())); + if metadata == &Value::Null { + *metadata = Value::Object(serde_json::Map::new()); } - let mut data = serde_json::from_slice(&body).unwrap(); - // use serde::Value to manipulate the json object and ensure that we don't lose any data - if let Value::Object(ref mut map) = data { - // serialize arch state and add to metadata - let metadata = map - .entry("metadata") - .or_insert(Value::Object(serde_json::Map::new())); - if metadata == &Value::Null { - *metadata = Value::Object(serde_json::Map::new()); - } - - // since arch gateway generates tool calls (using arch-fc) and calls upstream api to - // get response, we will send these back to developer so they can see the api response - // and tool call arch-fc generated - let fc_messages = vec![ - Message { - role: ASSISTANT_ROLE.to_string(), - content: None, - model: Some(ARCH_FC_MODEL_NAME.to_string()), - tool_calls: self.tool_calls.clone(), - tool_call_id: None, - }, - Message { - role: TOOL_ROLE.to_string(), - content: self.tool_call_response.clone(), - model: None, - tool_calls: None, - tool_call_id: Some(self.tool_calls.as_ref().unwrap()[0].id.clone()), - }, - ]; - let fc_messages_str = serde_json::to_string(&fc_messages).unwrap(); - let arch_state = HashMap::from([("messages".to_string(), fc_messages_str)]); - let arch_state_str = serde_json::to_string(&arch_state).unwrap(); - metadata.as_object_mut().unwrap().insert( - ARCH_STATE_HEADER.to_string(), - serde_json::Value::String(arch_state_str), - ); - let data_serialized = serde_json::to_string(&data).unwrap(); - debug!("archgw <= developer: {}", data_serialized); - self.set_http_response_body(0, body_size, data_serialized.as_bytes()); - }; - } + // since arch gateway generates tool calls (using arch-fc) and calls upstream api to + // get response, we will send these back to developer so they can see the api response + // and tool call arch-fc generated + let fc_messages = vec![ + Message { + role: ASSISTANT_ROLE.to_string(), + content: None, + model: Some(ARCH_FC_MODEL_NAME.to_string()), + tool_calls: self.tool_calls.clone(), + tool_call_id: None, + }, + Message { + role: TOOL_ROLE.to_string(), + content: self.tool_call_response.clone(), + model: None, + tool_calls: None, + tool_call_id: Some(self.tool_calls.as_ref().unwrap()[0].id.clone()), + }, + ]; + let fc_messages_str = serde_json::to_string(&fc_messages).unwrap(); + let arch_state = HashMap::from([("messages".to_string(), fc_messages_str)]); + let arch_state_str = serde_json::to_string(&arch_state).unwrap(); + metadata.as_object_mut().unwrap().insert( + ARCH_STATE_HEADER.to_string(), + serde_json::Value::String(arch_state_str), + ); + let data_serialized = serde_json::to_string(&data).unwrap(); + debug!("archgw <= developer: {}", data_serialized); + self.set_http_response_body(0, body_size, data_serialized.as_bytes()); + }; } } - trace!( - "recv [S={}] total_tokens={} end_stream={}", - self.context_id, - self.response_tokens, - end_of_stream - ); + trace!("recv [S={}] end_stream={}", self.context_id, end_of_stream); Action::Continue } diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index 6f4a36ea..ade29268 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -12,7 +12,12 @@ use common::common_types::{ }; use common::configuration::{Overrides, PromptGuards, PromptTarget}; use common::consts::{ - ARCH_FC_INTERNAL_HOST, ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, DEFAULT_EMBEDDING_MODEL, HALLUCINATION_TEMPLATE, DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, EMBEDDINGS_INTERNAL_HOST, HALLUCINATION_INTERNAL_HOST, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE, ZEROSHOT_INTERNAL_HOST + ARCH_FC_INTERNAL_HOST, ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, + ARCH_INTERNAL_CLUSTER_NAME, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, + ASSISTANT_ROLE, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL, + DEFAULT_PROMPT_TARGET_THRESHOLD, EMBEDDINGS_INTERNAL_HOST, HALLUCINATION_INTERNAL_HOST, + HALLUCINATION_TEMPLATE, MESSAGES_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE, + ZEROSHOT_INTERNAL_HOST, }; use common::embeddings::{ CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse, @@ -66,9 +71,8 @@ pub struct StreamContext { pub tool_call_response: Option, pub arch_state: Option>, pub request_body_size: usize, - pub streaming_response: bool, pub user_prompt: Option, - pub response_tokens: usize, + pub streaming_response: bool, pub is_chat_completions_request: bool, pub chat_completions_request: Option, pub prompt_guards: Rc, @@ -99,7 +103,6 @@ impl StreamContext { request_body_size: 0, streaming_response: false, user_prompt: None, - response_tokens: 0, is_chat_completions_request: false, prompt_guards, overrides, @@ -323,9 +326,7 @@ impl StreamContext { if !keys_with_low_score.is_empty() { let response = - HALLUCINATION_TEMPLATE.to_string() - + &keys_with_low_score.join(", ") - + " ?"; + HALLUCINATION_TEMPLATE.to_string() + &keys_with_low_score.join(", ") + " ?"; let message = Message { role: ASSISTANT_ROLE.to_string(), content: Some(response), diff --git a/demos/llm_routing/arch_config.yaml b/demos/llm_routing/arch_config.yaml index c5839bf4..e99b9687 100644 --- a/demos/llm_routing/arch_config.yaml +++ b/demos/llm_routing/arch_config.yaml @@ -17,11 +17,6 @@ llm_providers: provider: openai model: gpt-4o - - name: ministral-8b - access_key: $MISTRAL_API_KEY - provider: mistral - model: ministral-8b-latest - - name: ministral-3b access_key: $MISTRAL_API_KEY provider: mistral diff --git a/demos/llm_routing/docker-compose.yaml b/demos/llm_routing/docker-compose.yaml index f8200977..1ce6963b 100644 --- a/demos/llm_routing/docker-compose.yaml +++ b/demos/llm_routing/docker-compose.yaml @@ -10,3 +10,5 @@ services: - CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:12000/v1 extra_hosts: - "host.docker.internal:host-gateway" + volumes: + - ./arch_config.yaml:/app/arch_config.yaml