diff --git a/crates/common/src/common_types.rs b/crates/common/src/common_types.rs index 4ca9dbeb..38e1afa8 100644 --- a/crates/common/src/common_types.rs +++ b/crates/common/src/common_types.rs @@ -42,6 +42,8 @@ pub mod open_ai { use serde::{ser::SerializeMap, Deserialize, Serialize}; use serde_yaml::Value; + use crate::consts::{ARCH_FC_MODEL_NAME, ASSISTANT_ROLE}; + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChatCompletionsRequest { #[serde(default)] @@ -242,6 +244,28 @@ pub mod open_ai { pub metadata: Option>, } + // create constructor for ChatCompletionsResponse + impl ChatCompletionsResponse { + pub fn new(message: String) -> Self { + ChatCompletionsResponse { + choices: vec![Choice { + message: Message { + role: ASSISTANT_ROLE.to_string(), + content: Some(message), + model: Some(ARCH_FC_MODEL_NAME.to_string()), + tool_calls: None, + tool_call_id: None, + }, + index: 0, + finish_reason: "done".to_string(), + }], + usage: None, + model: ARCH_FC_MODEL_NAME.to_string(), + metadata: None, + } + } + } + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Usage { pub completion_tokens: usize, @@ -254,6 +278,24 @@ pub mod open_ai { pub choices: Vec, } + impl ChatCompletionStreamResponse { + pub fn new(response: Option, role: Option) -> Self { + ChatCompletionStreamResponse { + model: Some(ARCH_FC_MODEL_NAME.to_string()), + choices: vec![ChunkChoice { + delta: Delta { + role, + content: response, + tool_calls: None, + model: None, + tool_call_id: None, + }, + finish_reason: None, + }], + } + } + } + #[derive(Debug, thiserror::Error)] pub enum ChatCompletionChunkResponseError { #[error("failed to deserialize")] diff --git a/crates/common/src/consts.rs b/crates/common/src/consts.rs index 6922b34b..f782cf99 100644 --- a/crates/common/src/consts.rs +++ b/crates/common/src/consts.rs @@ -27,4 +27,4 @@ pub const ARCH_UPSTREAM_HOST_HEADER: &str = "x-arch-upstream"; pub const ARCH_LLM_UPSTREAM_LISTENER: &str = "arch_llm_listener"; pub const ARCH_MODEL_PREFIX: &str = "Arch"; pub const HALLUCINATION_TEMPLATE: &str = - "It seems I’m missing some information. Could you provide the following details "; + "It seems I'm missing some information. Could you provide the following details "; diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index 8899018f..bd5c13ff 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -2,9 +2,9 @@ use crate::filter_context::{EmbeddingsStore, WasmMetrics}; use crate::hallucination::extract_messages_for_hallucination; use acap::cos; use common::common_types::open_ai::{ - ArchState, ChatCompletionTool, ChatCompletionsRequest, ChatCompletionsResponse, Choice, - FunctionDefinition, FunctionParameter, FunctionParameters, Message, ParameterType, ToolCall, - ToolType, + ArchState, ChatCompletionStreamResponse, ChatCompletionTool, ChatCompletionsRequest, + ChatCompletionsResponse, ChunkChoice, Delta, FunctionDefinition, FunctionParameter, + FunctionParameters, Message, ParameterType, ToolCall, ToolType, }; use common::common_types::{ EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse, @@ -303,16 +303,16 @@ impl StreamContext { body: Vec, callout_context: StreamCallContext, ) { - let boyd_str = String::from_utf8(body).expect("could not convert body to string"); - debug!("archgw <= hallucination response: {}", boyd_str); + let body_str = String::from_utf8(body).expect("could not convert body to string"); + debug!("archgw <= hallucination response: {}", body_str); let hallucination_response: HallucinationClassificationResponse = - match serde_json::from_str(boyd_str.as_str()) { + match serde_json::from_str(body_str.as_str()) { Ok(hallucination_response) => hallucination_response, Err(e) => { warn!( "error deserializing hallucination response: {}, body: {}", e, - boyd_str.as_str() + body_str.as_str() ); return self.send_server_error(ServerError::Deserialization(e), None); } @@ -331,34 +331,31 @@ impl StreamContext { if !keys_with_low_score.is_empty() { let response = HALLUCINATION_TEMPLATE.to_string() + &keys_with_low_score.join(", ") + " ?"; - let message = Message { - role: ASSISTANT_ROLE.to_string(), - content: Some(response), - model: Some(ARCH_FC_MODEL_NAME.to_string()), - tool_calls: None, - tool_call_id: None, - }; - let chat_completion_response = ChatCompletionsResponse { - choices: vec![Choice { - message, - index: 0, - finish_reason: "done".to_string(), - }], - usage: None, - model: ARCH_FC_MODEL_NAME.to_string(), - metadata: None, - }; + let response_str = if self.streaming_response { + let chunks = [ + ChatCompletionStreamResponse::new(None, Some(ASSISTANT_ROLE.to_string())), + ChatCompletionStreamResponse::new(Some(response), None), + ]; - trace!("hallucination response: {:?}", chat_completion_response); + let mut response_str = String::new(); + for chunk in chunks.iter() { + response_str.push_str("data: "); + response_str.push_str(&serde_json::to_string(&chunk).unwrap()); + response_str.push_str("\n\n"); + } + response_str + } else { + let chat_completion_response = ChatCompletionsResponse::new(response); + serde_json::to_string(&chat_completion_response).unwrap() + }; + debug!("hallucination response: {:?}", response_str); + // make sure on_http_response_body does not attach tool calls and tool response to the response + self.tool_calls = None; self.send_http_response( StatusCode::OK.as_u16().into(), vec![("Powered-By", "Katanemo")], - Some( - serde_json::to_string(&chat_completion_response) - .unwrap() - .as_bytes(), - ), + Some(response_str.as_bytes()), ); } else { // not a hallucination, resume the flow @@ -948,7 +945,7 @@ impl StreamContext { self.get_embeddings(callout_context); } - pub fn default_target_handler(&self, body: Vec, callout_context: StreamCallContext) { + pub fn default_target_handler(&self, body: Vec, mut callout_context: StreamCallContext) { let prompt_target = self .prompt_targets .get(callout_context.prompt_target_name.as_ref().unwrap()) @@ -956,8 +953,48 @@ impl StreamContext { .clone(); // check if the default target should be dispatched to the LLM provider - if !prompt_target.auto_llm_dispatch_on_response.unwrap_or(false) { - let default_target_response_str = String::from_utf8(body).unwrap(); + if !prompt_target + .auto_llm_dispatch_on_response + .unwrap_or_default() + { + let default_target_response_str = if self.streaming_response { + let chat_completion_response = + serde_json::from_slice::(&body).unwrap(); + + let chunk_role_message = ChatCompletionStreamResponse { + model: Some(chat_completion_response.model.clone()), + choices: vec![ChunkChoice { + delta: Delta { + role: Some(USER_ROLE.to_string()), + content: None, + tool_calls: None, + model: None, + tool_call_id: None, + }, + finish_reason: None, + }], + }; + + let chat_completion_stream_response = ChatCompletionStreamResponse { + model: Some(chat_completion_response.model), + choices: vec![ChunkChoice { + delta: Delta { + role: None, + content: chat_completion_response.choices[0].message.content.clone(), + tool_calls: None, + model: None, + tool_call_id: None, + }, + finish_reason: None, + }], + }; + let chunk_role = serde_json::to_string(&chunk_role_message).unwrap(); + let chunk_data = serde_json::to_string(&chat_completion_stream_response).unwrap(); + format!("data: {}\n\ndata: {}\n\n", chunk_role, chunk_data) + } else { + String::from_utf8(body).unwrap() + }; + self.send_http_response( StatusCode::OK.as_u16().into(), vec![("Powered-By", "Katanemo")], @@ -965,20 +1002,20 @@ impl StreamContext { ); return; } + let chat_completions_resp: ChatCompletionsResponse = match serde_json::from_slice(&body) { Ok(chat_completions_resp) => chat_completions_resp, Err(e) => { - warn!("error deserializing default target response: {}", e); + warn!( + "error deserializing default target response: {}, body str: {}", + e, + String::from_utf8(body).unwrap() + ); return self.send_server_error(ServerError::Deserialization(e), None); } }; - let api_resp = chat_completions_resp.choices[0] - .message - .content - .as_ref() - .unwrap(); - let mut messages = callout_context.request_body.messages; + let mut messages = Vec::new(); // add system prompt match prompt_target.system_prompt.as_ref() { None => {} @@ -994,13 +1031,24 @@ impl StreamContext { } } + messages.append(&mut callout_context.request_body.messages); + + let api_resp = chat_completions_resp.choices[0] + .message + .content + .as_ref() + .unwrap(); + + let user_message = messages.pop().unwrap(); + let message = format!("{}\ncontext: {}", user_message.content.unwrap(), api_resp); messages.push(Message { role: USER_ROLE.to_string(), - content: Some(api_resp.clone()), + content: Some(message), model: None, tool_calls: None, tool_call_id: None, }); + let chat_completion_request = ChatCompletionsRequest { model: self .chat_completions_request @@ -1014,6 +1062,7 @@ impl StreamContext { stream_options: callout_context.request_body.stream_options, metadata: None, }; + let json_resp = serde_json::to_string(&chat_completion_request).unwrap(); debug!("archgw => (default target) llm request: {}", json_resp); self.set_http_request_body(0, self.request_body_size, json_resp.as_bytes()); diff --git a/demos/function_calling/api_server/app/main.py b/demos/function_calling/api_server/app/main.py index 041a921d..ed7c8f32 100644 --- a/demos/function_calling/api_server/app/main.py +++ b/demos/function_calling/api_server/app/main.py @@ -66,18 +66,18 @@ async def insurance_claim_details(req: InsuranceClaimDetailsRequest, res: Respon class DefaultTargetRequest(BaseModel): - arch_messages: list + messages: list @app.post("/default_target") async def default_target(req: DefaultTargetRequest, res: Response): - logger.info(f"Received arch_messages: {req.arch_messages}") + logger.info(f"Received arch_messages: {req.messages}") resp = { "choices": [ { "message": { - "role": "assistant", - "content": "hello world from api server", + "role": "user", + "content": "I can help you with weather forecast or insurance claim details", }, "finish_reason": "completed", "index": 0, diff --git a/demos/function_calling/arch_config.yaml b/demos/function_calling/arch_config.yaml index c4b0bd7b..e7448c7e 100644 --- a/demos/function_calling/arch_config.yaml +++ b/demos/function_calling/arch_config.yaml @@ -82,10 +82,10 @@ prompt_targets: name: api_server path: /default_target system_prompt: | - You are a helpful assistant. Use the information that is provided to you. + You are a helpful assistant! Summarize the user's request and provide a helpful response. # if it is set to false arch will send response that it received from this prompt target to the user # if true arch will forward the response to the default LLM - auto_llm_dispatch_on_response: true + auto_llm_dispatch_on_response: false tracing: random_sampling: 100 diff --git a/e2e_tests/common.py b/e2e_tests/common.py index 00ba513c..5449fa89 100644 --- a/e2e_tests/common.py +++ b/e2e_tests/common.py @@ -12,6 +12,7 @@ LLM_GATEWAY_ENDPOINT = os.getenv( def get_data_chunks(stream, n=1): chunks = [] for chunk in stream.iter_lines(): + print(chunk) if chunk: chunk = chunk.decode("utf-8") chunk_data_id = chunk[0:6] diff --git a/e2e_tests/test_prompt_gateway.py b/e2e_tests/test_prompt_gateway.py index 1eae6832..203965cf 100644 --- a/e2e_tests/test_prompt_gateway.py +++ b/e2e_tests/test_prompt_gateway.py @@ -43,7 +43,7 @@ def test_prompt_gateway_param_gathering(stream): response = requests.post(PROMPT_GATEWAY_ENDPOINT, json=body, stream=stream) assert response.status_code == 200 if stream: - chunks = get_data_chunks(response) + chunks = get_data_chunks(response, n=3) assert len(chunks) > 0 response_json = json.loads(chunks[0]) # if its streaming we return tool call and api call in first two chunks @@ -112,3 +112,33 @@ def test_prompt_gateway_param_tool_call(stream): else: response_json = response.json() assert response_json.get("model").startswith("gpt-4o-mini") + + +@pytest.mark.parametrize("stream", [True, False]) +def test_prompt_gateway_default_target(stream): + body = { + "messages": [ + { + "role": "user", + "content": "hello, what can you do for me?", + }, + ], + "stream": stream, + } + response = requests.post(PROMPT_GATEWAY_ENDPOINT, json=body, stream=stream) + assert response.status_code == 200 + if stream: + chunks = get_data_chunks(response, n=3) + assert len(chunks) > 0 + response_json = json.loads(chunks[0]) + assert response_json.get("model").startswith("api_server") + response_json = json.loads(chunks[1]) + choices = response_json.get("choices", []) + assert len(choices) > 0 + content = choices[0]["delta"]["content"] + assert ( + content == "I can help you with weather forecast or insurance claim details" + ) + else: + response_json = response.json() + assert response_json.get("model").startswith("api_server")