From aa6ce2ff7989f0b88cdd353160f6ee88025e35e1 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Thu, 5 Dec 2024 14:56:46 -0800 Subject: [PATCH] more changes --- crates/prompt_gateway/src/stream_context.rs | 111 +++++--------------- 1 file changed, 26 insertions(+), 85 deletions(-) diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index 501e25a9..e134e07c 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -494,53 +494,6 @@ impl StreamContext { .find(|pt| pt.default.unwrap_or(false)) { debug!("default prompt target found, forwarding request to default prompt target"); - if default_prompt_target.endpoint.is_none() { - info!("default prompt target endpoint not found"); - - let system_prompt = self.get_system_prompt(Some(default_prompt_target.clone())); - - let messages = vec![ - Message { - content: system_prompt, - role: SYSTEM_ROLE.to_string(), - model: Some(ARCH_FC_MODEL_NAME.to_string()), - tool_calls: None, - tool_call_id: None, - }, - Message { - content: self.user_prompt.as_ref().unwrap().content.clone(), - role: ASSISTANT_ROLE.to_string(), - model: Some(ARCH_FC_MODEL_NAME.to_string()), - tool_calls: None, - tool_call_id: None, - }, - ]; - - let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest { - model: callout_context.request_body.model, - messages, - tools: None, - stream: callout_context.request_body.stream, - stream_options: callout_context.request_body.stream_options, - metadata: None, - }; - - let llm_request_str = match serde_json::to_string(&chat_completions_request) { - Ok(json_string) => json_string, - Err(e) => { - return self.send_server_error(ServerError::Serialization(e), None); - } - }; - - self.set_http_request_body( - 0, - self.request_body_size, - &llm_request_str.into_bytes(), - ); - - self.resume_http_request(); - return; - } let endpoint = default_prompt_target.endpoint.clone().unwrap(); let upstream_path: String = endpoint.path.unwrap_or(String::from("/")); @@ -594,7 +547,7 @@ impl StreamContext { // if no default prompt target is found and similarity score is low send response to upstream llm // removing tool calls and tool response - let messages = self.construct_llm_messages(&callout_context); + let messages = self.filter_out_arch_messages(&callout_context); let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest { model: callout_context.request_body.model, @@ -940,7 +893,7 @@ impl StreamContext { let endpoint = prompt_target.endpoint.unwrap(); let path: String = endpoint.path.unwrap_or(String::from("/")); - // only add params that are of string and number type + // only add params that are of string, number and bool type let url_params = tool_params .iter() .filter(|(_, value)| value.is_number() || value.is_string() || value.is_bool()) @@ -1035,7 +988,7 @@ impl StreamContext { self.tool_call_response.as_ref().unwrap() ); - let mut messages = self.construct_llm_messages(&callout_context); + let mut messages = self.filter_out_arch_messages(&callout_context); let user_message = match messages.pop() { Some(user_message) => user_message, @@ -1092,43 +1045,25 @@ impl StreamContext { self.resume_http_request(); } - fn get_system_prompt(&self, prompt_target: Option) -> Option { - match prompt_target { - None => self.system_prompt.as_ref().clone(), - Some(prompt_target) => prompt_target.system_prompt, - } - } - - fn filter_out_arch_messages(&self, messages: &Vec) -> Vec { - messages - .into_iter() - .filter(|m| { - if m.role == TOOL_ROLE - || m.content.is_none() - || (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty()) - { - true - } else { - false - } - }) - .cloned() - .collect() - } - - fn construct_llm_messages(&mut self, callout_context: &StreamCallContext) -> Vec { + fn filter_out_arch_messages(&mut self, callout_context: &StreamCallContext) -> Vec { let mut messages: Vec = Vec::new(); - // add system prompt + let system_prompt = match callout_context.prompt_target_name.as_ref() { None => self.system_prompt.as_ref().clone(), Some(prompt_target_name) => { - self.get_system_prompt(self.prompt_targets.get(prompt_target_name).cloned()) + let prompt_system_prompt = self + .prompt_targets + .get(prompt_target_name) + .unwrap() + .clone() + .system_prompt; + match prompt_system_prompt { + None => self.system_prompt.as_ref().clone(), + Some(system_prompt) => Some(system_prompt), + } } }; - - info!("messages 1: {:?}", callout_context.request_body.messages); - if system_prompt.is_some() { let system_prompt_message = Message { role: SYSTEM_ROLE.to_string(), @@ -1140,12 +1075,18 @@ impl StreamContext { messages.push(system_prompt_message); } - info!("messages 2: {:?}", messages); + // don't send tools message and api response to chat gpt + for m in callout_context.request_body.messages.iter() { + // don't send api response and tool calls to upstream LLMs + if m.role == TOOL_ROLE + || m.content.is_none() + || (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty()) + { + continue; + } + messages.push(m.clone()); + } - messages.append( - &mut self.filter_out_arch_messages(callout_context.request_body.messages.as_ref()), - ); - info!("messages 3: {:?}", messages); messages }