mirror of
https://github.com/katanemo/plano.git
synced 2026-05-03 21:02:56 +02:00
add system prompt (#138)
This commit is contained in:
parent
c1cfbcd44d
commit
422efd3887
2 changed files with 47 additions and 23 deletions
|
|
@ -80,11 +80,14 @@ pub enum ServerError {
|
|||
Jailbreak(String),
|
||||
#[error("{why}")]
|
||||
BadRequest { why: String },
|
||||
#[error("{why}")]
|
||||
NoMessagesFound { why: String },
|
||||
}
|
||||
|
||||
pub struct StreamContext {
|
||||
context_id: u32,
|
||||
metrics: Rc<WasmMetrics>,
|
||||
system_prompt: Rc<Option<String>>,
|
||||
prompt_targets: Rc<HashMap<String, PromptTarget>>,
|
||||
embeddings_store: Rc<EmbeddingsStore>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
|
|
@ -108,6 +111,7 @@ impl StreamContext {
|
|||
pub fn new(
|
||||
context_id: u32,
|
||||
metrics: Rc<WasmMetrics>,
|
||||
system_prompt: Rc<Option<String>>,
|
||||
prompt_targets: Rc<HashMap<String, PromptTarget>>,
|
||||
prompt_guards: Rc<PromptGuards>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
|
|
@ -117,6 +121,7 @@ impl StreamContext {
|
|||
StreamContext {
|
||||
context_id,
|
||||
metrics,
|
||||
system_prompt,
|
||||
prompt_targets,
|
||||
embeddings_store,
|
||||
callouts: RefCell::new(HashMap::new()),
|
||||
|
|
@ -633,9 +638,12 @@ impl StreamContext {
|
|||
} else {
|
||||
warn!("http status code not found in api response");
|
||||
}
|
||||
let body_str: String = String::from_utf8(body).unwrap();
|
||||
self.tool_call_response = Some(body_str.clone());
|
||||
debug!("arch <= app response body: {}", body_str);
|
||||
let app_function_call_response_str: String = String::from_utf8(body).unwrap();
|
||||
self.tool_call_response = Some(app_function_call_response_str.clone());
|
||||
debug!(
|
||||
"arch <= app response body: {}",
|
||||
app_function_call_response_str
|
||||
);
|
||||
let prompt_target_name = callout_context.prompt_target_name.unwrap();
|
||||
let prompt_target = self
|
||||
.prompt_targets
|
||||
|
|
@ -644,36 +652,48 @@ impl StreamContext {
|
|||
.clone();
|
||||
|
||||
let mut messages: Vec<Message> = callout_context.request_body.messages.clone();
|
||||
let user_message = match messages.pop() {
|
||||
Some(user_message) => user_message,
|
||||
None => {
|
||||
return self.send_server_error(
|
||||
ServerError::NoMessagesFound {
|
||||
why: "no user messages found".to_string(),
|
||||
},
|
||||
None,
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
// add system prompt
|
||||
match prompt_target.system_prompt.as_ref() {
|
||||
None => {}
|
||||
Some(system_prompt) => {
|
||||
let system_prompt_message = Message {
|
||||
role: SYSTEM_ROLE.to_string(),
|
||||
content: Some(system_prompt.clone()),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
};
|
||||
messages.push(system_prompt_message);
|
||||
}
|
||||
}
|
||||
let system_prompt = match prompt_target.system_prompt.as_ref() {
|
||||
None => match self.system_prompt.as_ref() {
|
||||
None => None,
|
||||
Some(system_prompt) => Some(system_prompt.clone()),
|
||||
},
|
||||
Some(system_prompt) => Some(system_prompt.clone()),
|
||||
};
|
||||
|
||||
// add data from function call response
|
||||
messages.push({
|
||||
Message {
|
||||
role: USER_ROLE.to_string(),
|
||||
content: Some(body_str),
|
||||
if system_prompt.is_some() {
|
||||
let system_prompt_message = Message {
|
||||
role: SYSTEM_ROLE.to_string(),
|
||||
content: system_prompt,
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
}
|
||||
});
|
||||
};
|
||||
messages.push(system_prompt_message);
|
||||
}
|
||||
|
||||
let final_prompt = format!(
|
||||
"{}\nhere is context: {}",
|
||||
user_message.content.unwrap(),
|
||||
app_function_call_response_str
|
||||
);
|
||||
|
||||
// add original user prompt
|
||||
messages.push({
|
||||
Message {
|
||||
role: USER_ROLE.to_string(),
|
||||
content: Some(callout_context.user_message.unwrap()),
|
||||
content: Some(final_prompt),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue