diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index 3a7e9a9b..9e8f8a60 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -446,95 +446,91 @@ impl StreamContext { // it may be that arch fc is handling the conversation for parameter collection if arch_assistant { info!("arch fc is engaged in parameter collection"); - } else { - if let Some(default_prompt_target) = self - .prompt_targets - .values() - .find(|pt| pt.default.unwrap_or(false)) - { - debug!( - "default prompt target found, forwarding request to default prompt target" - ); - let endpoint = default_prompt_target.endpoint.clone().unwrap(); - let upstream_path: String = endpoint.path.unwrap_or(String::from("/")); + } else if let Some(default_prompt_target) = self + .prompt_targets + .values() + .find(|pt| pt.default.unwrap_or(false)) + { + debug!("default prompt target found, forwarding request to default prompt target"); + let endpoint = default_prompt_target.endpoint.clone().unwrap(); + let upstream_path: String = endpoint.path.unwrap_or(String::from("/")); - let upstream_endpoint = endpoint.name; - let mut params = HashMap::new(); - params.insert( - MESSAGES_KEY.to_string(), - callout_context.request_body.messages.clone(), - ); - let arch_messages_json = serde_json::to_string(¶ms).unwrap(); - let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string(); + let upstream_endpoint = endpoint.name; + let mut params = HashMap::new(); + params.insert( + MESSAGES_KEY.to_string(), + callout_context.request_body.messages.clone(), + ); + let arch_messages_json = serde_json::to_string(¶ms).unwrap(); + let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string(); - let mut headers = vec![ - (":method", "POST"), - (ARCH_UPSTREAM_HOST_HEADER, &upstream_endpoint), - (":path", &upstream_path), - (":authority", &upstream_endpoint), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()), - ]; + let mut headers = vec![ + (":method", "POST"), + (ARCH_UPSTREAM_HOST_HEADER, &upstream_endpoint), + (":path", &upstream_path), + (":authority", &upstream_endpoint), + ("content-type", "application/json"), + ("x-envoy-max-retries", "3"), + ("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()), + ]; - if self.request_id.is_some() { - headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap())); - } - - let call_args = CallArgs::new( - ARCH_INTERNAL_CLUSTER_NAME, - &upstream_path, - headers, - Some(arch_messages_json.as_bytes()), - vec![], - Duration::from_secs(5), - ); - callout_context.response_handler_type = ResponseHandlerType::DefaultTarget; - callout_context.prompt_target_name = Some(default_prompt_target.name.clone()); - - if let Err(e) = self.http_call(call_args, callout_context) { - warn!("error dispatching default prompt target request: {}", e); - return self.send_server_error( - ServerError::HttpDispatch(e), - Some(StatusCode::BAD_REQUEST), - ); - } - return; - } else { - // if no default prompt target is found and similarity score is low send response to upstream llm - // removing tool calls and tool response - - let messages = self.filter_out_arch_messages(&callout_context); - - let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest { - model: callout_context.request_body.model, - messages, - tools: None, - stream: callout_context.request_body.stream, - stream_options: callout_context.request_body.stream_options, - metadata: None, - }; - - let llm_request_str = match serde_json::to_string(&chat_completions_request) { - Ok(json_string) => json_string, - Err(e) => { - return self.send_server_error(ServerError::Serialization(e), None); - } - }; - debug!( - "archgw (low similarity score) => llm request: {}", - llm_request_str - ); - - self.set_http_request_body( - 0, - self.request_body_size, - &llm_request_str.into_bytes(), - ); - - self.resume_http_request(); - return; + if self.request_id.is_some() { + headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap())); } + + let call_args = CallArgs::new( + ARCH_INTERNAL_CLUSTER_NAME, + &upstream_path, + headers, + Some(arch_messages_json.as_bytes()), + vec![], + Duration::from_secs(5), + ); + callout_context.response_handler_type = ResponseHandlerType::DefaultTarget; + callout_context.prompt_target_name = Some(default_prompt_target.name.clone()); + + if let Err(e) = self.http_call(call_args, callout_context) { + warn!("error dispatching default prompt target request: {}", e); + return self.send_server_error( + ServerError::HttpDispatch(e), + Some(StatusCode::BAD_REQUEST), + ); + } + return; + } else { + // if no default prompt target is found and similarity score is low send response to upstream llm + // removing tool calls and tool response + + let messages = self.filter_out_arch_messages(&callout_context); + + let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest { + model: callout_context.request_body.model, + messages, + tools: None, + stream: callout_context.request_body.stream, + stream_options: callout_context.request_body.stream_options, + metadata: None, + }; + + let llm_request_str = match serde_json::to_string(&chat_completions_request) { + Ok(json_string) => json_string, + Err(e) => { + return self.send_server_error(ServerError::Serialization(e), None); + } + }; + debug!( + "archgw (low similarity score) => llm request: {}", + llm_request_str + ); + + self.set_http_request_body( + 0, + self.request_body_size, + &llm_request_str.into_bytes(), + ); + + self.resume_http_request(); + return; } } @@ -959,17 +955,22 @@ impl StreamContext { } fn filter_out_arch_messages(&mut self, callout_context: &StreamCallContext) -> Vec { - let mut messages: Vec = Vec::new(); // add system prompt + let system_prompt = match callout_context.prompt_target_name.as_ref() { None => self.system_prompt.as_ref().clone(), Some(prompt_target_name) => { - self.prompt_targets + let prompt_system_prompt = self + .prompt_targets .get(prompt_target_name) .unwrap() .clone() - .system_prompt + .system_prompt; + match prompt_system_prompt { + None => self.system_prompt.as_ref().clone(), + Some(system_prompt) => Some(system_prompt), + } } }; if system_prompt.is_some() { diff --git a/demos/hr_agent/arch_config.yaml b/demos/hr_agent/arch_config.yaml index 2b56cb6e..429a4e84 100644 --- a/demos/hr_agent/arch_config.yaml +++ b/demos/hr_agent/arch_config.yaml @@ -9,7 +9,7 @@ llm_providers: - name: OpenAI provider: openai access_key: $OPENAI_API_KEY - model: gpt-4o + model: gpt-4o-mini default: true # Arch creates a round-robin load balancing between different endpoints, managed via the cluster subsystem. @@ -26,8 +26,7 @@ endpoints: system_prompt: | You are a Workforce assistant that helps on workforce planning and HR decision makers with reporting and workforce planning. Use following rules when responding, - when you get data in json format, offer some summary but don't be too verbose - - be concise and to the point - - if you don't have data, say so and offer to help with something else + - be concise, to the point and do not over analyze the data prompt_targets: - name: workforce diff --git a/demos/hr_agent/main.py b/demos/hr_agent/main.py index 9c49d120..e1c4bbe3 100644 --- a/demos/hr_agent/main.py +++ b/demos/hr_agent/main.py @@ -30,7 +30,7 @@ with open("workforce_data.json") as file: # Define the request model -class WorkforceRequset(BaseModel): +class WorkforceRequest(BaseModel): region: str staffing_type: str data_snapshot_days_ago: Optional[int] = None @@ -74,7 +74,7 @@ def send_slack_message(request: SlackRequest): # Post method for device summary @app.post("/agent/workforce") -def get_workforce(request: WorkforceRequset): +def get_workforce(request: WorkforceRequest): """ Endpoint to workforce data by region, staffing type at a given point in time. """ @@ -90,7 +90,7 @@ def get_workforce(request: WorkforceRequset): "region": region, "staffing_type": f"Staffing agency: {staffing_type}", "headcount": f"Headcount: {int(workforce_data_df[(workforce_data_df['region']==region) & (workforce_data_df['data_snapshot_days_ago']==data_snapshot_days_ago)][staffing_type].values[0])}", - "satisfaction": f"Satisifaction: {float(workforce_data_df[(workforce_data_df['region']==region) & (workforce_data_df['data_snapshot_days_ago']==data_snapshot_days_ago)]['satisfaction'].values[0])}", + "satisfaction": f"Satisfaction: {float(workforce_data_df[(workforce_data_df['region']==region) & (workforce_data_df['data_snapshot_days_ago']==data_snapshot_days_ago)]['satisfaction'].values[0])}", } return response