diff --git a/arch/src/filter_context.rs b/arch/src/filter_context.rs index fa0f29fc..e0a80596 100644 --- a/arch/src/filter_context.rs +++ b/arch/src/filter_context.rs @@ -48,6 +48,7 @@ pub struct FilterContext { // callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request. callouts: RefCell>, overrides: Rc>, + system_prompt: Rc>, prompt_targets: Rc>, prompt_guards: Rc, llm_providers: Option>, @@ -60,6 +61,7 @@ impl FilterContext { FilterContext { callouts: RefCell::new(HashMap::new()), metrics: Rc::new(WasmMetrics::new()), + system_prompt: Rc::new(None), prompt_targets: Rc::new(HashMap::new()), overrides: Rc::new(None), prompt_guards: Rc::new(PromptGuards::default()), @@ -245,6 +247,7 @@ impl RootContext for FilterContext { for pt in config.prompt_targets { prompt_targets.insert(pt.name.clone(), pt.clone()); } + self.system_prompt = Rc::new(config.system_prompt); self.prompt_targets = Rc::new(prompt_targets); ratelimit::ratelimits(config.ratelimits); @@ -273,6 +276,7 @@ impl RootContext for FilterContext { Some(Box::new(StreamContext::new( context_id, Rc::clone(&self.metrics), + Rc::clone(&self.system_prompt), Rc::clone(&self.prompt_targets), Rc::clone(&self.prompt_guards), Rc::clone(&self.overrides), diff --git a/arch/src/stream_context.rs b/arch/src/stream_context.rs index 7cd02734..567fb186 100644 --- a/arch/src/stream_context.rs +++ b/arch/src/stream_context.rs @@ -80,11 +80,14 @@ pub enum ServerError { Jailbreak(String), #[error("{why}")] BadRequest { why: String }, + #[error("{why}")] + NoMessagesFound { why: String }, } pub struct StreamContext { context_id: u32, metrics: Rc, + system_prompt: Rc>, prompt_targets: Rc>, embeddings_store: Rc, overrides: Rc>, @@ -108,6 +111,7 @@ impl StreamContext { pub fn new( context_id: u32, metrics: Rc, + system_prompt: Rc>, prompt_targets: Rc>, prompt_guards: Rc, overrides: Rc>, @@ -117,6 +121,7 @@ impl StreamContext { StreamContext { context_id, metrics, + system_prompt, prompt_targets, embeddings_store, callouts: RefCell::new(HashMap::new()), @@ -633,9 +638,12 @@ impl StreamContext { } else { warn!("http status code not found in api response"); } - let body_str: String = String::from_utf8(body).unwrap(); - self.tool_call_response = Some(body_str.clone()); - debug!("arch <= app response body: {}", body_str); + let app_function_call_response_str: String = String::from_utf8(body).unwrap(); + self.tool_call_response = Some(app_function_call_response_str.clone()); + debug!( + "arch <= app response body: {}", + app_function_call_response_str + ); let prompt_target_name = callout_context.prompt_target_name.unwrap(); let prompt_target = self .prompt_targets @@ -644,36 +652,48 @@ impl StreamContext { .clone(); let mut messages: Vec = callout_context.request_body.messages.clone(); + let user_message = match messages.pop() { + Some(user_message) => user_message, + None => { + return self.send_server_error( + ServerError::NoMessagesFound { + why: "no user messages found".to_string(), + }, + None, + ); + } + }; // add system prompt - match prompt_target.system_prompt.as_ref() { - None => {} - Some(system_prompt) => { - let system_prompt_message = Message { - role: SYSTEM_ROLE.to_string(), - content: Some(system_prompt.clone()), - model: None, - tool_calls: None, - }; - messages.push(system_prompt_message); - } - } + let system_prompt = match prompt_target.system_prompt.as_ref() { + None => match self.system_prompt.as_ref() { + None => None, + Some(system_prompt) => Some(system_prompt.clone()), + }, + Some(system_prompt) => Some(system_prompt.clone()), + }; - // add data from function call response - messages.push({ - Message { - role: USER_ROLE.to_string(), - content: Some(body_str), + if system_prompt.is_some() { + let system_prompt_message = Message { + role: SYSTEM_ROLE.to_string(), + content: system_prompt, model: None, tool_calls: None, - } - }); + }; + messages.push(system_prompt_message); + } + + let final_prompt = format!( + "{}\nhere is context: {}", + user_message.content.unwrap(), + app_function_call_response_str + ); // add original user prompt messages.push({ Message { role: USER_ROLE.to_string(), - content: Some(callout_context.user_message.unwrap()), + content: Some(final_prompt), model: None, tool_calls: None, }