add system prompt (#138)

This commit is contained in:
Adil Hafeez 2024-10-07 17:25:37 -07:00 committed by GitHub
parent c1cfbcd44d
commit 422efd3887
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 47 additions and 23 deletions

View file

@ -48,6 +48,7 @@ pub struct FilterContext {
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request. // callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
callouts: RefCell<HashMap<u32, FilterCallContext>>, callouts: RefCell<HashMap<u32, FilterCallContext>>,
overrides: Rc<Option<Overrides>>, overrides: Rc<Option<Overrides>>,
system_prompt: Rc<Option<String>>,
prompt_targets: Rc<HashMap<String, PromptTarget>>, prompt_targets: Rc<HashMap<String, PromptTarget>>,
prompt_guards: Rc<PromptGuards>, prompt_guards: Rc<PromptGuards>,
llm_providers: Option<Rc<LlmProviders>>, llm_providers: Option<Rc<LlmProviders>>,
@ -60,6 +61,7 @@ impl FilterContext {
FilterContext { FilterContext {
callouts: RefCell::new(HashMap::new()), callouts: RefCell::new(HashMap::new()),
metrics: Rc::new(WasmMetrics::new()), metrics: Rc::new(WasmMetrics::new()),
system_prompt: Rc::new(None),
prompt_targets: Rc::new(HashMap::new()), prompt_targets: Rc::new(HashMap::new()),
overrides: Rc::new(None), overrides: Rc::new(None),
prompt_guards: Rc::new(PromptGuards::default()), prompt_guards: Rc::new(PromptGuards::default()),
@ -245,6 +247,7 @@ impl RootContext for FilterContext {
for pt in config.prompt_targets { for pt in config.prompt_targets {
prompt_targets.insert(pt.name.clone(), pt.clone()); prompt_targets.insert(pt.name.clone(), pt.clone());
} }
self.system_prompt = Rc::new(config.system_prompt);
self.prompt_targets = Rc::new(prompt_targets); self.prompt_targets = Rc::new(prompt_targets);
ratelimit::ratelimits(config.ratelimits); ratelimit::ratelimits(config.ratelimits);
@ -273,6 +276,7 @@ impl RootContext for FilterContext {
Some(Box::new(StreamContext::new( Some(Box::new(StreamContext::new(
context_id, context_id,
Rc::clone(&self.metrics), Rc::clone(&self.metrics),
Rc::clone(&self.system_prompt),
Rc::clone(&self.prompt_targets), Rc::clone(&self.prompt_targets),
Rc::clone(&self.prompt_guards), Rc::clone(&self.prompt_guards),
Rc::clone(&self.overrides), Rc::clone(&self.overrides),

View file

@ -80,11 +80,14 @@ pub enum ServerError {
Jailbreak(String), Jailbreak(String),
#[error("{why}")] #[error("{why}")]
BadRequest { why: String }, BadRequest { why: String },
#[error("{why}")]
NoMessagesFound { why: String },
} }
pub struct StreamContext { pub struct StreamContext {
context_id: u32, context_id: u32,
metrics: Rc<WasmMetrics>, metrics: Rc<WasmMetrics>,
system_prompt: Rc<Option<String>>,
prompt_targets: Rc<HashMap<String, PromptTarget>>, prompt_targets: Rc<HashMap<String, PromptTarget>>,
embeddings_store: Rc<EmbeddingsStore>, embeddings_store: Rc<EmbeddingsStore>,
overrides: Rc<Option<Overrides>>, overrides: Rc<Option<Overrides>>,
@ -108,6 +111,7 @@ impl StreamContext {
pub fn new( pub fn new(
context_id: u32, context_id: u32,
metrics: Rc<WasmMetrics>, metrics: Rc<WasmMetrics>,
system_prompt: Rc<Option<String>>,
prompt_targets: Rc<HashMap<String, PromptTarget>>, prompt_targets: Rc<HashMap<String, PromptTarget>>,
prompt_guards: Rc<PromptGuards>, prompt_guards: Rc<PromptGuards>,
overrides: Rc<Option<Overrides>>, overrides: Rc<Option<Overrides>>,
@ -117,6 +121,7 @@ impl StreamContext {
StreamContext { StreamContext {
context_id, context_id,
metrics, metrics,
system_prompt,
prompt_targets, prompt_targets,
embeddings_store, embeddings_store,
callouts: RefCell::new(HashMap::new()), callouts: RefCell::new(HashMap::new()),
@ -633,9 +638,12 @@ impl StreamContext {
} else { } else {
warn!("http status code not found in api response"); warn!("http status code not found in api response");
} }
let body_str: String = String::from_utf8(body).unwrap(); let app_function_call_response_str: String = String::from_utf8(body).unwrap();
self.tool_call_response = Some(body_str.clone()); self.tool_call_response = Some(app_function_call_response_str.clone());
debug!("arch <= app response body: {}", body_str); debug!(
"arch <= app response body: {}",
app_function_call_response_str
);
let prompt_target_name = callout_context.prompt_target_name.unwrap(); let prompt_target_name = callout_context.prompt_target_name.unwrap();
let prompt_target = self let prompt_target = self
.prompt_targets .prompt_targets
@ -644,36 +652,48 @@ impl StreamContext {
.clone(); .clone();
let mut messages: Vec<Message> = callout_context.request_body.messages.clone(); let mut messages: Vec<Message> = callout_context.request_body.messages.clone();
let user_message = match messages.pop() {
Some(user_message) => user_message,
None => {
return self.send_server_error(
ServerError::NoMessagesFound {
why: "no user messages found".to_string(),
},
None,
);
}
};
// add system prompt // add system prompt
match prompt_target.system_prompt.as_ref() { let system_prompt = match prompt_target.system_prompt.as_ref() {
None => {} None => match self.system_prompt.as_ref() {
Some(system_prompt) => { None => None,
Some(system_prompt) => Some(system_prompt.clone()),
},
Some(system_prompt) => Some(system_prompt.clone()),
};
if system_prompt.is_some() {
let system_prompt_message = Message { let system_prompt_message = Message {
role: SYSTEM_ROLE.to_string(), role: SYSTEM_ROLE.to_string(),
content: Some(system_prompt.clone()), content: system_prompt,
model: None, model: None,
tool_calls: None, tool_calls: None,
}; };
messages.push(system_prompt_message); messages.push(system_prompt_message);
} }
}
// add data from function call response let final_prompt = format!(
messages.push({ "{}\nhere is context: {}",
Message { user_message.content.unwrap(),
role: USER_ROLE.to_string(), app_function_call_response_str
content: Some(body_str), );
model: None,
tool_calls: None,
}
});
// add original user prompt // add original user prompt
messages.push({ messages.push({
Message { Message {
role: USER_ROLE.to_string(), role: USER_ROLE.to_string(),
content: Some(callout_context.user_message.unwrap()), content: Some(final_prompt),
model: None, model: None,
tool_calls: None, tool_calls: None,
} }