more change

This commit is contained in:
Adil Hafeez 2024-10-30 17:27:01 -07:00
parent d679158f91
commit 6e2beaa968
2 changed files with 81 additions and 40 deletions

View file

@ -499,10 +499,42 @@ impl StreamContext {
Some(StatusCode::BAD_REQUEST),
);
}
}
return;
} else {
// if no default prompt target is found and similarity score is low send response to upstream llm
// removing tool calls and tool response
self.resume_http_request();
return;
let messages = self.filter_out_arch_messages(&callout_context);
let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest {
model: callout_context.request_body.model,
messages,
tools: None,
stream: callout_context.request_body.stream,
stream_options: callout_context.request_body.stream_options,
metadata: None,
};
let llm_request_str = match serde_json::to_string(&chat_completions_request) {
Ok(json_string) => json_string,
Err(e) => {
return self.send_server_error(ServerError::Serialization(e), None);
}
};
debug!(
"archgw (low similarity score) => llm request: {}",
llm_request_str
);
self.set_http_request_body(
0,
self.request_body_size,
&llm_request_str.into_bytes(),
);
self.resume_http_request();
return;
}
}
}
@ -873,42 +905,8 @@ impl StreamContext {
"archgw <= api call response: {}",
self.tool_call_response.as_ref().unwrap()
);
let prompt_target_name = callout_context.prompt_target_name.unwrap();
let prompt_target = self
.prompt_targets
.get(&prompt_target_name)
.unwrap()
.clone();
let mut messages: Vec<Message> = Vec::new();
// add system prompt
let system_prompt = match prompt_target.system_prompt.as_ref() {
None => self.system_prompt.as_ref().clone(),
Some(system_prompt) => Some(system_prompt.clone()),
};
if system_prompt.is_some() {
let system_prompt_message = Message {
role: SYSTEM_ROLE.to_string(),
content: system_prompt,
model: None,
tool_calls: None,
tool_call_id: None,
};
messages.push(system_prompt_message);
}
// don't send tools message and api response to chat gpt
for m in callout_context.request_body.messages.iter() {
// don't send api response and tool calls to upstream LLMs
if m.role == TOOL_ROLE
|| m.content.is_none()
|| (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty())
{
continue;
}
messages.push(m.clone());
}
let mut messages = self.filter_out_arch_messages(&callout_context);
let user_message = match messages.pop() {
Some(user_message) => user_message,
@ -960,6 +958,46 @@ impl StreamContext {
self.resume_http_request();
}
fn filter_out_arch_messages(&mut self, callout_context: &StreamCallContext) -> Vec<Message> {
let mut messages: Vec<Message> = Vec::new();
// add system prompt
let system_prompt = match callout_context.prompt_target_name.as_ref() {
None => self.system_prompt.as_ref().clone(),
Some(prompt_target_name) => {
self.prompt_targets
.get(prompt_target_name)
.unwrap()
.clone()
.system_prompt
}
};
if system_prompt.is_some() {
let system_prompt_message = Message {
role: SYSTEM_ROLE.to_string(),
content: system_prompt,
model: None,
tool_calls: None,
tool_call_id: None,
};
messages.push(system_prompt_message);
}
// don't send tools message and api response to chat gpt
for m in callout_context.request_body.messages.iter() {
// don't send api response and tool calls to upstream LLMs
if m.role == TOOL_ROLE
|| m.content.is_none()
|| (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty())
{
continue;
}
messages.push(m.clone());
}
messages
}
pub fn arch_guard_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) {
let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap();
debug!(