add system prompt (#138)

This commit is contained in:
Adil Hafeez 2024-10-07 17:25:37 -07:00 committed by GitHub
parent c1cfbcd44d
commit 422efd3887
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 47 additions and 23 deletions

View file

@ -80,11 +80,14 @@ pub enum ServerError {
Jailbreak(String),
#[error("{why}")]
BadRequest { why: String },
#[error("{why}")]
NoMessagesFound { why: String },
}
pub struct StreamContext {
context_id: u32,
metrics: Rc<WasmMetrics>,
system_prompt: Rc<Option<String>>,
prompt_targets: Rc<HashMap<String, PromptTarget>>,
embeddings_store: Rc<EmbeddingsStore>,
overrides: Rc<Option<Overrides>>,
@ -108,6 +111,7 @@ impl StreamContext {
pub fn new(
context_id: u32,
metrics: Rc<WasmMetrics>,
system_prompt: Rc<Option<String>>,
prompt_targets: Rc<HashMap<String, PromptTarget>>,
prompt_guards: Rc<PromptGuards>,
overrides: Rc<Option<Overrides>>,
@ -117,6 +121,7 @@ impl StreamContext {
StreamContext {
context_id,
metrics,
system_prompt,
prompt_targets,
embeddings_store,
callouts: RefCell::new(HashMap::new()),
@ -633,9 +638,12 @@ impl StreamContext {
} else {
warn!("http status code not found in api response");
}
let body_str: String = String::from_utf8(body).unwrap();
self.tool_call_response = Some(body_str.clone());
debug!("arch <= app response body: {}", body_str);
let app_function_call_response_str: String = String::from_utf8(body).unwrap();
self.tool_call_response = Some(app_function_call_response_str.clone());
debug!(
"arch <= app response body: {}",
app_function_call_response_str
);
let prompt_target_name = callout_context.prompt_target_name.unwrap();
let prompt_target = self
.prompt_targets
@ -644,36 +652,48 @@ impl StreamContext {
.clone();
let mut messages: Vec<Message> = callout_context.request_body.messages.clone();
let user_message = match messages.pop() {
Some(user_message) => user_message,
None => {
return self.send_server_error(
ServerError::NoMessagesFound {
why: "no user messages found".to_string(),
},
None,
);
}
};
// add system prompt
match prompt_target.system_prompt.as_ref() {
None => {}
Some(system_prompt) => {
let system_prompt_message = Message {
role: SYSTEM_ROLE.to_string(),
content: Some(system_prompt.clone()),
model: None,
tool_calls: None,
};
messages.push(system_prompt_message);
}
}
let system_prompt = match prompt_target.system_prompt.as_ref() {
None => match self.system_prompt.as_ref() {
None => None,
Some(system_prompt) => Some(system_prompt.clone()),
},
Some(system_prompt) => Some(system_prompt.clone()),
};
// add data from function call response
messages.push({
Message {
role: USER_ROLE.to_string(),
content: Some(body_str),
if system_prompt.is_some() {
let system_prompt_message = Message {
role: SYSTEM_ROLE.to_string(),
content: system_prompt,
model: None,
tool_calls: None,
}
});
};
messages.push(system_prompt_message);
}
let final_prompt = format!(
"{}\nhere is context: {}",
user_message.content.unwrap(),
app_function_call_response_str
);
// add original user prompt
messages.push({
Message {
role: USER_ROLE.to_string(),
content: Some(callout_context.user_message.unwrap()),
content: Some(final_prompt),
model: None,
tool_calls: None,
}