mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
more changes
This commit is contained in:
parent
63d7d91267
commit
aa6ce2ff79
1 changed files with 26 additions and 85 deletions
|
|
@ -494,53 +494,6 @@ impl StreamContext {
|
|||
.find(|pt| pt.default.unwrap_or(false))
|
||||
{
|
||||
debug!("default prompt target found, forwarding request to default prompt target");
|
||||
if default_prompt_target.endpoint.is_none() {
|
||||
info!("default prompt target endpoint not found");
|
||||
|
||||
let system_prompt = self.get_system_prompt(Some(default_prompt_target.clone()));
|
||||
|
||||
let messages = vec![
|
||||
Message {
|
||||
content: system_prompt,
|
||||
role: SYSTEM_ROLE.to_string(),
|
||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
Message {
|
||||
content: self.user_prompt.as_ref().unwrap().content.clone(),
|
||||
role: ASSISTANT_ROLE.to_string(),
|
||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
];
|
||||
|
||||
let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest {
|
||||
model: callout_context.request_body.model,
|
||||
messages,
|
||||
tools: None,
|
||||
stream: callout_context.request_body.stream,
|
||||
stream_options: callout_context.request_body.stream_options,
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let llm_request_str = match serde_json::to_string(&chat_completions_request) {
|
||||
Ok(json_string) => json_string,
|
||||
Err(e) => {
|
||||
return self.send_server_error(ServerError::Serialization(e), None);
|
||||
}
|
||||
};
|
||||
|
||||
self.set_http_request_body(
|
||||
0,
|
||||
self.request_body_size,
|
||||
&llm_request_str.into_bytes(),
|
||||
);
|
||||
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
let endpoint = default_prompt_target.endpoint.clone().unwrap();
|
||||
let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));
|
||||
|
||||
|
|
@ -594,7 +547,7 @@ impl StreamContext {
|
|||
// if no default prompt target is found and similarity score is low send response to upstream llm
|
||||
// removing tool calls and tool response
|
||||
|
||||
let messages = self.construct_llm_messages(&callout_context);
|
||||
let messages = self.filter_out_arch_messages(&callout_context);
|
||||
|
||||
let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest {
|
||||
model: callout_context.request_body.model,
|
||||
|
|
@ -940,7 +893,7 @@ impl StreamContext {
|
|||
let endpoint = prompt_target.endpoint.unwrap();
|
||||
let path: String = endpoint.path.unwrap_or(String::from("/"));
|
||||
|
||||
// only add params that are of string and number type
|
||||
// only add params that are of string, number and bool type
|
||||
let url_params = tool_params
|
||||
.iter()
|
||||
.filter(|(_, value)| value.is_number() || value.is_string() || value.is_bool())
|
||||
|
|
@ -1035,7 +988,7 @@ impl StreamContext {
|
|||
self.tool_call_response.as_ref().unwrap()
|
||||
);
|
||||
|
||||
let mut messages = self.construct_llm_messages(&callout_context);
|
||||
let mut messages = self.filter_out_arch_messages(&callout_context);
|
||||
|
||||
let user_message = match messages.pop() {
|
||||
Some(user_message) => user_message,
|
||||
|
|
@ -1092,43 +1045,25 @@ impl StreamContext {
|
|||
self.resume_http_request();
|
||||
}
|
||||
|
||||
fn get_system_prompt(&self, prompt_target: Option<PromptTarget>) -> Option<String> {
|
||||
match prompt_target {
|
||||
None => self.system_prompt.as_ref().clone(),
|
||||
Some(prompt_target) => prompt_target.system_prompt,
|
||||
}
|
||||
}
|
||||
|
||||
fn filter_out_arch_messages(&self, messages: &Vec<Message>) -> Vec<Message> {
|
||||
messages
|
||||
.into_iter()
|
||||
.filter(|m| {
|
||||
if m.role == TOOL_ROLE
|
||||
|| m.content.is_none()
|
||||
|| (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty())
|
||||
{
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn construct_llm_messages(&mut self, callout_context: &StreamCallContext) -> Vec<Message> {
|
||||
fn filter_out_arch_messages(&mut self, callout_context: &StreamCallContext) -> Vec<Message> {
|
||||
let mut messages: Vec<Message> = Vec::new();
|
||||
|
||||
// add system prompt
|
||||
|
||||
let system_prompt = match callout_context.prompt_target_name.as_ref() {
|
||||
None => self.system_prompt.as_ref().clone(),
|
||||
Some(prompt_target_name) => {
|
||||
self.get_system_prompt(self.prompt_targets.get(prompt_target_name).cloned())
|
||||
let prompt_system_prompt = self
|
||||
.prompt_targets
|
||||
.get(prompt_target_name)
|
||||
.unwrap()
|
||||
.clone()
|
||||
.system_prompt;
|
||||
match prompt_system_prompt {
|
||||
None => self.system_prompt.as_ref().clone(),
|
||||
Some(system_prompt) => Some(system_prompt),
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
info!("messages 1: {:?}", callout_context.request_body.messages);
|
||||
|
||||
if system_prompt.is_some() {
|
||||
let system_prompt_message = Message {
|
||||
role: SYSTEM_ROLE.to_string(),
|
||||
|
|
@ -1140,12 +1075,18 @@ impl StreamContext {
|
|||
messages.push(system_prompt_message);
|
||||
}
|
||||
|
||||
info!("messages 2: {:?}", messages);
|
||||
// don't send tools message and api response to chat gpt
|
||||
for m in callout_context.request_body.messages.iter() {
|
||||
// don't send api response and tool calls to upstream LLMs
|
||||
if m.role == TOOL_ROLE
|
||||
|| m.content.is_none()
|
||||
|| (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty())
|
||||
{
|
||||
continue;
|
||||
}
|
||||
messages.push(m.clone());
|
||||
}
|
||||
|
||||
messages.append(
|
||||
&mut self.filter_out_arch_messages(callout_context.request_body.messages.as_ref()),
|
||||
);
|
||||
info!("messages 3: {:?}", messages);
|
||||
messages
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue