more change

This commit is contained in:
Adil Hafeez 2024-10-30 17:27:01 -07:00
parent d679158f91
commit 6e2beaa968
2 changed files with 81 additions and 40 deletions

View file

@ -499,10 +499,42 @@ impl StreamContext {
Some(StatusCode::BAD_REQUEST), Some(StatusCode::BAD_REQUEST),
); );
} }
} return;
} else {
// if no default prompt target is found and the similarity score is low, forward the request to the upstream llm
// after removing tool calls and tool responses
self.resume_http_request(); let messages = self.filter_out_arch_messages(&callout_context);
return;
let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest {
model: callout_context.request_body.model,
messages,
tools: None,
stream: callout_context.request_body.stream,
stream_options: callout_context.request_body.stream_options,
metadata: None,
};
let llm_request_str = match serde_json::to_string(&chat_completions_request) {
Ok(json_string) => json_string,
Err(e) => {
return self.send_server_error(ServerError::Serialization(e), None);
}
};
debug!(
"archgw (low similarity score) => llm request: {}",
llm_request_str
);
self.set_http_request_body(
0,
self.request_body_size,
&llm_request_str.into_bytes(),
);
self.resume_http_request();
return;
}
} }
} }
@ -873,42 +905,8 @@ impl StreamContext {
"archgw <= api call response: {}", "archgw <= api call response: {}",
self.tool_call_response.as_ref().unwrap() self.tool_call_response.as_ref().unwrap()
); );
let prompt_target_name = callout_context.prompt_target_name.unwrap();
let prompt_target = self
.prompt_targets
.get(&prompt_target_name)
.unwrap()
.clone();
let mut messages: Vec<Message> = Vec::new(); let mut messages = self.filter_out_arch_messages(&callout_context);
// add system prompt
let system_prompt = match prompt_target.system_prompt.as_ref() {
None => self.system_prompt.as_ref().clone(),
Some(system_prompt) => Some(system_prompt.clone()),
};
if system_prompt.is_some() {
let system_prompt_message = Message {
role: SYSTEM_ROLE.to_string(),
content: system_prompt,
model: None,
tool_calls: None,
tool_call_id: None,
};
messages.push(system_prompt_message);
}
// don't send tools message and api response to chat gpt
for m in callout_context.request_body.messages.iter() {
// don't send api response and tool calls to upstream LLMs
if m.role == TOOL_ROLE
|| m.content.is_none()
|| (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty())
{
continue;
}
messages.push(m.clone());
}
let user_message = match messages.pop() { let user_message = match messages.pop() {
Some(user_message) => user_message, Some(user_message) => user_message,
@ -960,6 +958,46 @@ impl StreamContext {
self.resume_http_request(); self.resume_http_request();
} }
/// Builds the message list to forward upstream with Arch-internal messages removed.
///
/// The result starts with a system message when a system prompt is available, followed by
/// clones of the messages in `callout_context.request_body.messages`, skipping every message
/// whose role is `TOOL_ROLE`, that has no content, or that carries a non-empty `tool_calls`.
///
/// The system prompt comes from the prompt target named by
/// `callout_context.prompt_target_name`. When no target is named, the named target is not
/// present in `self.prompt_targets`, or the target defines no system prompt of its own,
/// `self.system_prompt` is used instead.
fn filter_out_arch_messages(&mut self, callout_context: &StreamCallContext) -> Vec<Message> {
    let mut messages: Vec<Message> = Vec::new();

    // Prefer the prompt target's own system prompt; otherwise fall back to the default one.
    // Only the prompt string is cloned, and a missing target no longer panics.
    let system_prompt = callout_context
        .prompt_target_name
        .as_ref()
        .and_then(|prompt_target_name| self.prompt_targets.get(prompt_target_name))
        .and_then(|prompt_target| prompt_target.system_prompt.clone())
        .or_else(|| self.system_prompt.as_ref().clone());

    if system_prompt.is_some() {
        messages.push(Message {
            role: SYSTEM_ROLE.to_string(),
            content: system_prompt,
            model: None,
            tool_calls: None,
            tool_call_id: None,
        });
    }

    // don't send api responses and tool calls to upstream LLMs
    messages.extend(
        callout_context
            .request_body
            .messages
            .iter()
            .filter(|m| {
                let has_tool_calls = m
                    .tool_calls
                    .as_ref()
                    .map_or(false, |tool_calls| !tool_calls.is_empty());
                m.role != TOOL_ROLE && m.content.is_some() && !has_tool_calls
            })
            .cloned(),
    );

    messages
}
pub fn arch_guard_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) { pub fn arch_guard_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) {
let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap(); let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap();
debug!( debug!(

View file

@ -9,7 +9,7 @@ llm_providers:
- name: OpenAI - name: OpenAI
provider: openai provider: openai
access_key: $OPENAI_API_KEY access_key: $OPENAI_API_KEY
model: gpt-4o-mini model: gpt-4o
default: true default: true
# Arch creates a round-robin load balancing between different endpoints, managed via the cluster subsystem. # Arch creates a round-robin load balancing between different endpoints, managed via the cluster subsystem.
@ -24,7 +24,10 @@ endpoints:
# default system prompt used by all prompt targets # default system prompt used by all prompt targets
system_prompt: | system_prompt: |
You are a Workforce assistant that helps on workforce planning and HR decision makers with reporting and workfoce planning. NOTHING ELSE. When you get data in json format, offer some summary but don't be too verbose. You are a Workforce assistant that helps on workforce planning and HR decision makers with reporting and workforce planning. Use following rules when responding,
- when you get data in json format, offer some summary but don't be too verbose
- be concise and to the point
- if you don't have data, say so and offer to help with something else
prompt_targets: prompt_targets:
- name: workforce - name: workforce