more changes

This commit is contained in:
Adil Hafeez 2024-12-05 14:56:46 -08:00
parent 63d7d91267
commit aa6ce2ff79

View file

@ -494,53 +494,6 @@ impl StreamContext {
.find(|pt| pt.default.unwrap_or(false))
{
debug!("default prompt target found, forwarding request to default prompt target");
if default_prompt_target.endpoint.is_none() {
info!("default prompt target endpoint not found");
let system_prompt = self.get_system_prompt(Some(default_prompt_target.clone()));
let messages = vec![
Message {
content: system_prompt,
role: SYSTEM_ROLE.to_string(),
model: Some(ARCH_FC_MODEL_NAME.to_string()),
tool_calls: None,
tool_call_id: None,
},
Message {
content: self.user_prompt.as_ref().unwrap().content.clone(),
role: ASSISTANT_ROLE.to_string(),
model: Some(ARCH_FC_MODEL_NAME.to_string()),
tool_calls: None,
tool_call_id: None,
},
];
let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest {
model: callout_context.request_body.model,
messages,
tools: None,
stream: callout_context.request_body.stream,
stream_options: callout_context.request_body.stream_options,
metadata: None,
};
let llm_request_str = match serde_json::to_string(&chat_completions_request) {
Ok(json_string) => json_string,
Err(e) => {
return self.send_server_error(ServerError::Serialization(e), None);
}
};
self.set_http_request_body(
0,
self.request_body_size,
&llm_request_str.into_bytes(),
);
self.resume_http_request();
return;
}
let endpoint = default_prompt_target.endpoint.clone().unwrap();
let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));
@ -594,7 +547,7 @@ impl StreamContext {
// if no default prompt target is found and similarity score is low send response to upstream llm
// removing tool calls and tool response
let messages = self.construct_llm_messages(&callout_context);
let messages = self.filter_out_arch_messages(&callout_context);
let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest {
model: callout_context.request_body.model,
@ -940,7 +893,7 @@ impl StreamContext {
let endpoint = prompt_target.endpoint.unwrap();
let path: String = endpoint.path.unwrap_or(String::from("/"));
// only add params that are of string and number type
// only add params that are of string, number and bool type
let url_params = tool_params
.iter()
.filter(|(_, value)| value.is_number() || value.is_string() || value.is_bool())
@ -1035,7 +988,7 @@ impl StreamContext {
self.tool_call_response.as_ref().unwrap()
);
let mut messages = self.construct_llm_messages(&callout_context);
let mut messages = self.filter_out_arch_messages(&callout_context);
let user_message = match messages.pop() {
Some(user_message) => user_message,
@ -1092,43 +1045,25 @@ impl StreamContext {
self.resume_http_request();
}
fn get_system_prompt(&self, prompt_target: Option<PromptTarget>) -> Option<String> {
match prompt_target {
None => self.system_prompt.as_ref().clone(),
Some(prompt_target) => prompt_target.system_prompt,
}
}
fn filter_out_arch_messages(&self, messages: &Vec<Message>) -> Vec<Message> {
messages
.into_iter()
.filter(|m| {
if m.role == TOOL_ROLE
|| m.content.is_none()
|| (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty())
{
true
} else {
false
}
})
.cloned()
.collect()
}
fn construct_llm_messages(&mut self, callout_context: &StreamCallContext) -> Vec<Message> {
fn filter_out_arch_messages(&mut self, callout_context: &StreamCallContext) -> Vec<Message> {
let mut messages: Vec<Message> = Vec::new();
// add system prompt
let system_prompt = match callout_context.prompt_target_name.as_ref() {
None => self.system_prompt.as_ref().clone(),
Some(prompt_target_name) => {
self.get_system_prompt(self.prompt_targets.get(prompt_target_name).cloned())
let prompt_system_prompt = self
.prompt_targets
.get(prompt_target_name)
.unwrap()
.clone()
.system_prompt;
match prompt_system_prompt {
None => self.system_prompt.as_ref().clone(),
Some(system_prompt) => Some(system_prompt),
}
}
};
info!("messages 1: {:?}", callout_context.request_body.messages);
if system_prompt.is_some() {
let system_prompt_message = Message {
role: SYSTEM_ROLE.to_string(),
@ -1140,12 +1075,18 @@ impl StreamContext {
messages.push(system_prompt_message);
}
info!("messages 2: {:?}", messages);
// don't send tools message and api response to chat gpt
for m in callout_context.request_body.messages.iter() {
// don't send api response and tool calls to upstream LLMs
if m.role == TOOL_ROLE
|| m.content.is_none()
|| (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty())
{
continue;
}
messages.push(m.clone());
}
messages.append(
&mut self.filter_out_arch_messages(callout_context.request_body.messages.as_ref()),
);
info!("messages 3: {:?}", messages);
messages
}