diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index a2fe0b2e..3a7e9a9b 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -499,10 +499,42 @@ impl StreamContext { Some(StatusCode::BAD_REQUEST), ); } - } + return; + } else { + // if no default prompt target is found and similarity score is low send response to upstream llm + // removing tool calls and tool response - self.resume_http_request(); - return; + let messages = self.filter_out_arch_messages(&callout_context); + + let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest { + model: callout_context.request_body.model, + messages, + tools: None, + stream: callout_context.request_body.stream, + stream_options: callout_context.request_body.stream_options, + metadata: None, + }; + + let llm_request_str = match serde_json::to_string(&chat_completions_request) { + Ok(json_string) => json_string, + Err(e) => { + return self.send_server_error(ServerError::Serialization(e), None); + } + }; + debug!( + "archgw (low similarity score) => llm request: {}", + llm_request_str + ); + + self.set_http_request_body( + 0, + self.request_body_size, + &llm_request_str.into_bytes(), + ); + + self.resume_http_request(); + return; + } } } @@ -873,42 +905,8 @@ impl StreamContext { "archgw <= api call response: {}", self.tool_call_response.as_ref().unwrap() ); - let prompt_target_name = callout_context.prompt_target_name.unwrap(); - let prompt_target = self - .prompt_targets - .get(&prompt_target_name) - .unwrap() - .clone(); - let mut messages: Vec = Vec::new(); - - // add system prompt - let system_prompt = match prompt_target.system_prompt.as_ref() { - None => self.system_prompt.as_ref().clone(), - Some(system_prompt) => Some(system_prompt.clone()), - }; - if system_prompt.is_some() { - let system_prompt_message = Message { - role: SYSTEM_ROLE.to_string(), - content: system_prompt, - model: None, - tool_calls: None, - tool_call_id: None, - }; - messages.push(system_prompt_message); - } - - // don't send tools message and api response to chat gpt - for m in callout_context.request_body.messages.iter() { - // don't send api response and tool calls to upstream LLMs - if m.role == TOOL_ROLE - || m.content.is_none() - || (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty()) - { - continue; - } - messages.push(m.clone()); - } + let mut messages = self.filter_out_arch_messages(&callout_context); let user_message = match messages.pop() { Some(user_message) => user_message, @@ -960,6 +958,46 @@ impl StreamContext { self.resume_http_request(); } + fn filter_out_arch_messages(&mut self, callout_context: &StreamCallContext) -> Vec { + + let mut messages: Vec = Vec::new(); + // add system prompt + let system_prompt = match callout_context.prompt_target_name.as_ref() { + None => self.system_prompt.as_ref().clone(), + Some(prompt_target_name) => { + self.prompt_targets + .get(prompt_target_name) + .unwrap() + .clone() + .system_prompt + } + }; + if system_prompt.is_some() { + let system_prompt_message = Message { + role: SYSTEM_ROLE.to_string(), + content: system_prompt, + model: None, + tool_calls: None, + tool_call_id: None, + }; + messages.push(system_prompt_message); + } + + // don't send tools message and api response to chat gpt + for m in callout_context.request_body.messages.iter() { + // don't send api response and tool calls to upstream LLMs + if m.role == TOOL_ROLE + || m.content.is_none() + || (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty()) + { + continue; + } + messages.push(m.clone()); + } + + messages + } + pub fn arch_guard_handler(&mut self, body: Vec, callout_context: StreamCallContext) { let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap(); debug!( diff --git a/demos/hr_agent/arch_config.yaml b/demos/hr_agent/arch_config.yaml index 1e93dfa1..2b56cb6e 100644 --- a/demos/hr_agent/arch_config.yaml +++ b/demos/hr_agent/arch_config.yaml @@ -9,7 +9,7 @@ llm_providers: - name: OpenAI provider: openai access_key: $OPENAI_API_KEY - model: gpt-4o-mini + model: gpt-4o default: true # Arch creates a round-robin load balancing between different endpoints, managed via the cluster subsystem. @@ -24,7 +24,10 @@ endpoints: # default system prompt used by all prompt targets system_prompt: | - You are a Workforce assistant that helps on workforce planning and HR decision makers with reporting and workfoce planning. NOTHING ELSE. When you get data in json format, offer some summary but don't be too verbose. + You are a Workforce assistant that helps on workforce planning and HR decision makers with reporting and workforce planning. Use following rules when responding, + - when you get data in json format, offer some summary but don't be too verbose + - be concise and to the point + - if you don't have data, say so and offer to help with something else prompt_targets: - name: workforce