diff --git a/demos/function_calling/bolt_config.yaml b/demos/function_calling/bolt_config.yaml index 5d7c6e3a..d936213f 100644 --- a/demos/function_calling/bolt_config.yaml +++ b/demos/function_calling/bolt_config.yaml @@ -40,6 +40,10 @@ prompt_targets: - name: policy_number required: true description: The policy number for which the insurance claim details are requested. + - name: include_expired + description: Include expired insurance claims in the response. + type: string + default: "false" endpoint: cluster: weatherhost path: /insurance_claim_details diff --git a/envoyfilter/src/stream_context.rs b/envoyfilter/src/stream_context.rs index 6e85180a..557f0c90 100644 --- a/envoyfilter/src/stream_context.rs +++ b/envoyfilter/src/stream_context.rs @@ -343,6 +343,7 @@ impl StreamContext { description: entity.description.clone(), required: entity.required, enum_values: entity.enum_values.clone(), + default: entity.default.clone(), }; properties.insert(entity.name.clone(), param); } @@ -425,7 +426,7 @@ impl StreamContext { debug!("response received for function resolver"); let body_str = String::from_utf8(body).unwrap(); - debug!("function_resolver response str: {:?}", body_str); + debug!("function_resolver response str: {}", body_str); let boltfc_response: ChatCompletionsResponse = serde_json::from_str(&body_str).unwrap(); @@ -459,8 +460,8 @@ impl StreamContext { .get(callout_context.prompt_target_name.as_ref().unwrap()) .unwrap() .clone(); - // verify required parameters are present + // verify required parameters are present prompt_target .parameters .as_ref() @@ -492,13 +493,20 @@ impl StreamContext { tool_name, &prompt_target.name ); } - let tool_params = &tools_call_response.tool_calls[0].arguments; - debug!("tool_name: {:?}", tool_name); - debug!("tool_params: {:?}", tool_params); - debug!("prompt_target: {:?}", prompt_target); + // extract all tool names + let tool_names: Vec = tools_call_response + .tool_calls + .iter() + .map(|tool_call| tool_call.name.clone()) + .collect(); + let tool_params = &tools_call_response.tool_calls[0].arguments; let tool_params_json_str = serde_json::to_string(&tool_params).unwrap(); + debug!("tool_name(s): {:?}", tool_names); + debug!("tool_params: {}", tool_params_json_str); + debug!("prompt_target_name: {}", prompt_target.name); + let endpoint = prompt_target.endpoint.as_ref().unwrap(); let token_id = match self.dispatch_http_call( &endpoint.cluster, diff --git a/public_types/src/common_types.rs b/public_types/src/common_types.rs index ead35d10..1b0d2f6f 100644 --- a/public_types/src/common_types.rs +++ b/public_types/src/common_types.rs @@ -45,6 +45,8 @@ pub struct ToolParameter { #[serde(skip_serializing_if = "Option::is_none")] #[serde(rename = "enum")] pub enum_values: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub default: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/public_types/src/configuration.rs b/public_types/src/configuration.rs index 3a0de1d9..7cd9ad59 100644 --- a/public_types/src/configuration.rs +++ b/public_types/src/configuration.rs @@ -105,6 +105,7 @@ pub struct Parameter { pub required: Option, #[serde(rename = "enum")] pub enum_values: Option>, + pub default: Option, } #[derive(Debug, Clone, Serialize, Deserialize)]