mirror of
https://github.com/katanemo/plano.git
synced 2026-07-02 15:51:02 +02:00
more changes
This commit is contained in:
parent
63d7d91267
commit
aa6ce2ff79
1 changed files with 26 additions and 85 deletions
|
|
@ -494,53 +494,6 @@ impl StreamContext {
|
||||||
.find(|pt| pt.default.unwrap_or(false))
|
.find(|pt| pt.default.unwrap_or(false))
|
||||||
{
|
{
|
||||||
debug!("default prompt target found, forwarding request to default prompt target");
|
debug!("default prompt target found, forwarding request to default prompt target");
|
||||||
if default_prompt_target.endpoint.is_none() {
|
|
||||||
info!("default prompt target endpoint not found");
|
|
||||||
|
|
||||||
let system_prompt = self.get_system_prompt(Some(default_prompt_target.clone()));
|
|
||||||
|
|
||||||
let messages = vec![
|
|
||||||
Message {
|
|
||||||
content: system_prompt,
|
|
||||||
role: SYSTEM_ROLE.to_string(),
|
|
||||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
|
||||||
tool_calls: None,
|
|
||||||
tool_call_id: None,
|
|
||||||
},
|
|
||||||
Message {
|
|
||||||
content: self.user_prompt.as_ref().unwrap().content.clone(),
|
|
||||||
role: ASSISTANT_ROLE.to_string(),
|
|
||||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
|
||||||
tool_calls: None,
|
|
||||||
tool_call_id: None,
|
|
||||||
},
|
|
||||||
];
|
|
||||||
|
|
||||||
let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest {
|
|
||||||
model: callout_context.request_body.model,
|
|
||||||
messages,
|
|
||||||
tools: None,
|
|
||||||
stream: callout_context.request_body.stream,
|
|
||||||
stream_options: callout_context.request_body.stream_options,
|
|
||||||
metadata: None,
|
|
||||||
};
|
|
||||||
|
|
||||||
let llm_request_str = match serde_json::to_string(&chat_completions_request) {
|
|
||||||
Ok(json_string) => json_string,
|
|
||||||
Err(e) => {
|
|
||||||
return self.send_server_error(ServerError::Serialization(e), None);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
self.set_http_request_body(
|
|
||||||
0,
|
|
||||||
self.request_body_size,
|
|
||||||
&llm_request_str.into_bytes(),
|
|
||||||
);
|
|
||||||
|
|
||||||
self.resume_http_request();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
let endpoint = default_prompt_target.endpoint.clone().unwrap();
|
let endpoint = default_prompt_target.endpoint.clone().unwrap();
|
||||||
let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));
|
let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));
|
||||||
|
|
||||||
|
|
@ -594,7 +547,7 @@ impl StreamContext {
|
||||||
// if no default prompt target is found and similarity score is low send response to upstream llm
|
// if no default prompt target is found and similarity score is low send response to upstream llm
|
||||||
// removing tool calls and tool response
|
// removing tool calls and tool response
|
||||||
|
|
||||||
let messages = self.construct_llm_messages(&callout_context);
|
let messages = self.filter_out_arch_messages(&callout_context);
|
||||||
|
|
||||||
let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest {
|
let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest {
|
||||||
model: callout_context.request_body.model,
|
model: callout_context.request_body.model,
|
||||||
|
|
@ -940,7 +893,7 @@ impl StreamContext {
|
||||||
let endpoint = prompt_target.endpoint.unwrap();
|
let endpoint = prompt_target.endpoint.unwrap();
|
||||||
let path: String = endpoint.path.unwrap_or(String::from("/"));
|
let path: String = endpoint.path.unwrap_or(String::from("/"));
|
||||||
|
|
||||||
// only add params that are of string and number type
|
// only add params that are of string, number and bool type
|
||||||
let url_params = tool_params
|
let url_params = tool_params
|
||||||
.iter()
|
.iter()
|
||||||
.filter(|(_, value)| value.is_number() || value.is_string() || value.is_bool())
|
.filter(|(_, value)| value.is_number() || value.is_string() || value.is_bool())
|
||||||
|
|
@ -1035,7 +988,7 @@ impl StreamContext {
|
||||||
self.tool_call_response.as_ref().unwrap()
|
self.tool_call_response.as_ref().unwrap()
|
||||||
);
|
);
|
||||||
|
|
||||||
let mut messages = self.construct_llm_messages(&callout_context);
|
let mut messages = self.filter_out_arch_messages(&callout_context);
|
||||||
|
|
||||||
let user_message = match messages.pop() {
|
let user_message = match messages.pop() {
|
||||||
Some(user_message) => user_message,
|
Some(user_message) => user_message,
|
||||||
|
|
@ -1092,43 +1045,25 @@ impl StreamContext {
|
||||||
self.resume_http_request();
|
self.resume_http_request();
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_system_prompt(&self, prompt_target: Option<PromptTarget>) -> Option<String> {
|
fn filter_out_arch_messages(&mut self, callout_context: &StreamCallContext) -> Vec<Message> {
|
||||||
match prompt_target {
|
|
||||||
None => self.system_prompt.as_ref().clone(),
|
|
||||||
Some(prompt_target) => prompt_target.system_prompt,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn filter_out_arch_messages(&self, messages: &Vec<Message>) -> Vec<Message> {
|
|
||||||
messages
|
|
||||||
.into_iter()
|
|
||||||
.filter(|m| {
|
|
||||||
if m.role == TOOL_ROLE
|
|
||||||
|| m.content.is_none()
|
|
||||||
|| (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty())
|
|
||||||
{
|
|
||||||
true
|
|
||||||
} else {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.cloned()
|
|
||||||
.collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn construct_llm_messages(&mut self, callout_context: &StreamCallContext) -> Vec<Message> {
|
|
||||||
let mut messages: Vec<Message> = Vec::new();
|
let mut messages: Vec<Message> = Vec::new();
|
||||||
|
|
||||||
// add system prompt
|
// add system prompt
|
||||||
|
|
||||||
let system_prompt = match callout_context.prompt_target_name.as_ref() {
|
let system_prompt = match callout_context.prompt_target_name.as_ref() {
|
||||||
None => self.system_prompt.as_ref().clone(),
|
None => self.system_prompt.as_ref().clone(),
|
||||||
Some(prompt_target_name) => {
|
Some(prompt_target_name) => {
|
||||||
self.get_system_prompt(self.prompt_targets.get(prompt_target_name).cloned())
|
let prompt_system_prompt = self
|
||||||
|
.prompt_targets
|
||||||
|
.get(prompt_target_name)
|
||||||
|
.unwrap()
|
||||||
|
.clone()
|
||||||
|
.system_prompt;
|
||||||
|
match prompt_system_prompt {
|
||||||
|
None => self.system_prompt.as_ref().clone(),
|
||||||
|
Some(system_prompt) => Some(system_prompt),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
info!("messages 1: {:?}", callout_context.request_body.messages);
|
|
||||||
|
|
||||||
if system_prompt.is_some() {
|
if system_prompt.is_some() {
|
||||||
let system_prompt_message = Message {
|
let system_prompt_message = Message {
|
||||||
role: SYSTEM_ROLE.to_string(),
|
role: SYSTEM_ROLE.to_string(),
|
||||||
|
|
@ -1140,12 +1075,18 @@ impl StreamContext {
|
||||||
messages.push(system_prompt_message);
|
messages.push(system_prompt_message);
|
||||||
}
|
}
|
||||||
|
|
||||||
info!("messages 2: {:?}", messages);
|
// don't send tools message and api response to chat gpt
|
||||||
|
for m in callout_context.request_body.messages.iter() {
|
||||||
|
// don't send api response and tool calls to upstream LLMs
|
||||||
|
if m.role == TOOL_ROLE
|
||||||
|
|| m.content.is_none()
|
||||||
|
|| (m.tool_calls.is_some() && !m.tool_calls.as_ref().unwrap().is_empty())
|
||||||
|
{
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
messages.push(m.clone());
|
||||||
|
}
|
||||||
|
|
||||||
messages.append(
|
|
||||||
&mut self.filter_out_arch_messages(callout_context.request_body.messages.as_ref()),
|
|
||||||
);
|
|
||||||
info!("messages 3: {:?}", messages);
|
|
||||||
messages
|
messages
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue