diff --git a/demos/function_calling/bolt_config.yaml b/demos/function_calling/bolt_config.yaml index a3d115dd..1c5b1a56 100644 --- a/demos/function_calling/bolt_config.yaml +++ b/demos/function_calling/bolt_config.yaml @@ -23,10 +23,6 @@ prompt_targets: - type: function_resolver name: weather_forecast description: This function resolver provides weather forecast information for a given city. - few_shot_examples: - - what is the weather in New York? - - how is the weather in San Francisco? - - what is the forecast in Chicago? parameters: - name: city required: true @@ -42,3 +38,16 @@ prompt_targets: You are a helpful weather forecaster. Use weater data that is provided to you. Please following following guidelines when responding to user queries: - Use farenheight for temperature - Use miles per hour for wind speed + - type: function_resolver + name: insurance_claim_details + description: This function resolver provides insurance claim details for a given policy number. + parameters: + - name: policy_number + required: true + description: The policy number for which the insurance claim details are requested. + endpoint: + cluster: weatherhost + path: /insurance_claim_details + system_prompt: | + You are a helpful insurance claim details provider. Use insurance claim data that is provided to you. 
Please follow the following guidelines when responding to user queries: + - Use policy number to retrieve insurance claim details diff --git a/envoyfilter/src/stream_context.rs b/envoyfilter/src/stream_context.rs index a2b4586d..880c1498 100644 --- a/envoyfilter/src/stream_context.rs +++ b/envoyfilter/src/stream_context.rs @@ -41,7 +41,7 @@ enum ResponseHandlerType { pub struct CallContext { response_handler_type: ResponseHandlerType, user_message: Option, - prompt_target: Option, + prompt_target_name: Option, request_body: ChatCompletionsRequest, similarity_scores: Option>, } @@ -325,46 +325,47 @@ impl StreamContext { .unwrap() .clone(); - info!( - "prompt_target name: {:?}, type: {:?}", - prompt_target.name, prompt_target.prompt_type - ); + info!("prompt_target name: {:?}", prompt_target_name); match prompt_target.prompt_type { PromptType::FunctionResolver => { - // only extract entity names - let properties: HashMap = match prompt_target.parameters { - // Clone is unavoidable here because we don't want to move the values out of the prompt target struct. 
- Some(ref entities) => { - let mut properties: HashMap = HashMap::new(); - for entity in entities.iter() { - let param = ToolParameter { - parameter_type: entity.parameter_type.clone(), - description: entity.description.clone(), - required: entity.required, - enum_values: entity.enum_values.clone(), - }; - properties.insert(entity.name.clone(), param); - } - properties - } - None => HashMap::new(), - }; - let tools_parameters = ToolParameters { - parameters_type: "dict".to_string(), - properties, - }; + let mut tools_definitions: Vec = Vec::new(); - let tools_defintion: ToolsDefinition = ToolsDefinition { - name: prompt_target.name.clone(), - description: prompt_target.description.clone(), - parameters: tools_parameters, - }; + for pt in self.prompt_targets.read().unwrap().values() { + // only extract entity names + let properties: HashMap = match pt.parameters { + // Clone is unavoidable here because we don't want to move the values out of the prompt target struct. + Some(ref entities) => { + let mut properties: HashMap = HashMap::new(); + for entity in entities.iter() { + let param = ToolParameter { + parameter_type: entity.parameter_type.clone(), + description: entity.description.clone(), + required: entity.required, + enum_values: entity.enum_values.clone(), + }; + properties.insert(entity.name.clone(), param); + } + properties + } + None => HashMap::new(), + }; + let tools_parameters = ToolParameters { + parameters_type: "dict".to_string(), + properties, + }; + + tools_definitions.push(ToolsDefinition { + name: pt.name.clone(), + description: pt.description.clone(), + parameters: tools_parameters, + }); + } let chat_completions = ChatCompletionsRequest { model: GPT_35_TURBO.to_string(), messages: callout_context.request_body.messages.clone(), - tools: Some(vec![tools_defintion]), + tools: Some(tools_definitions), stream: false, stream_options: None, }; @@ -411,7 +412,7 @@ impl StreamContext { ); callout_context.response_handler_type = 
ResponseHandlerType::FunctionResolver; - callout_context.prompt_target = Some(prompt_target); + callout_context.prompt_target_name = Some(prompt_target.name); if self.callouts.insert(token_id, callout_context).is_some() { panic!("duplicate token_id") } @@ -438,7 +439,7 @@ impl StreamContext { // Let's send the response back to the user to initalize lightweight dialog for parameter collection // add resolver name to the response so the client can send the response back to the correct resolver - boltfc_response.resolver_name = Some(callout_context.prompt_target.unwrap().name); + boltfc_response.resolver_name = Some(callout_context.prompt_target_name.unwrap()); info!("some requred parameters are missing, sending response from Bolt FC back to user for parameter collection: {}", e); let bolt_fc_dialogue_message = serde_json::to_string(&boltfc_response).unwrap(); self.send_http_response( @@ -450,11 +451,18 @@ impl StreamContext { } }; - // verify required parameters are present - callout_context - .prompt_target - .as_ref() + // prompt target + + let prompt_target = self + .prompt_targets + .read() .unwrap() + .get(callout_context.prompt_target_name.as_ref().unwrap()) + .unwrap() + .clone(); + // verify required parameters are present + + prompt_target .parameters .as_ref() .unwrap() @@ -477,10 +485,17 @@ impl StreamContext { debug!("tool_call_details: {:?}", tools_call_response); let tool_name = &tools_call_response.tool_calls[0].name; + + // ensure that detected tool name matches the prompt target name + if tool_name != &prompt_target.name { + warn!( + "tool name mismatch: detected tool name: {}, expected tool name: {}", + tool_name, &prompt_target.name + ); + } let tool_params = &tools_call_response.tool_calls[0].arguments; debug!("tool_name: {:?}", tool_name); debug!("tool_params: {:?}", tool_params); - let prompt_target = callout_context.prompt_target.as_ref().unwrap(); debug!("prompt_target: {:?}", prompt_target); let tool_params_json_str = 
serde_json::to_string(&tool_params).unwrap(); @@ -516,7 +531,14 @@ impl StreamContext { debug!("response received for function call response"); let body_str: String = String::from_utf8(body).unwrap(); debug!("function_call_response response str: {:?}", body_str); - let prompt_target = callout_context.prompt_target.as_ref().unwrap(); + let prompt_target_name = callout_context.prompt_target_name.unwrap(); + let prompt_target = self + .prompt_targets + .read() + .unwrap() + .get(&prompt_target_name) + .unwrap() + .clone(); let mut messages: Vec = callout_context.request_body.messages.clone(); @@ -714,7 +736,7 @@ impl HttpContext for StreamContext { let call_context = CallContext { response_handler_type: ResponseHandlerType::GetEmbeddings, user_message: Some(user_message), - prompt_target: None, + prompt_target_name: None, request_body: deserialized_body, similarity_scores: None, }; diff --git a/model_server/app/main.py b/model_server/app/main.py index 37a5fd4b..1bdb0352 100644 --- a/model_server/app/main.py +++ b/model_server/app/main.py @@ -143,3 +143,19 @@ async def weather(req: WeatherRequest, res: Response): }) return weather_forecast + +class InsuranceClaimDetailsRequest(BaseModel): + policy_number: str + +@app.post("/insurance_claim_details") +async def insurance_claim_details(req: InsuranceClaimDetailsRequest, res: Response): + + claim_details = { + "policy_number": req.policy_number, + "claim_status": "Approved", + "claim_amount": random.randrange(1000, 10000), + "claim_date": str(date.today() - timedelta(days=random.randrange(1, 30))), + "claim_reason": "Car Accident", + } + + return claim_details