diff --git a/chatbot_ui/app/run.py b/chatbot_ui/app/run.py index 380d0ddd..034378a7 100644 --- a/chatbot_ui/app/run.py +++ b/chatbot_ui/app/run.py @@ -33,8 +33,8 @@ def predict(message, history): # remove last user message in case of exception history.pop() log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT) - log.info("Error with OpenAI API: {}".format(e.message)) - raise gr.Error("Error with OpenAI API: {}".format(e.message)) + log.info("Error calling gateway API: {}".format(e.message)) + raise gr.Error("Error calling gateway API: {}".format(e.message)) # for chunk in response: # if chunk.choices[0].delta.content is not None: diff --git a/demos/function_calling/bolt_config.yaml b/demos/function_calling/bolt_config.yaml index 11f31472..9b99364b 100644 --- a/demos/function_calling/bolt_config.yaml +++ b/demos/function_calling/bolt_config.yaml @@ -40,6 +40,7 @@ prompt_targets: - name: policy_number required: true description: The policy number for which the insurance claim details are requested. + type: string - name: include_expired description: Include expired insurance claims in the response. type: string diff --git a/demos/function_calling/docker-compose.yaml b/demos/function_calling/docker-compose.yaml index fdb0175f..5e05ae3b 100644 --- a/demos/function_calling/docker-compose.yaml +++ b/demos/function_calling/docker-compose.yaml @@ -40,6 +40,7 @@ services: retries: 20 volumes: - ~/.cache/huggingface:/root/.cache/huggingface + - ./bolt_config.yaml:/root/bolt_config.yaml function_resolver: build: diff --git a/envoyfilter/src/filter_context.rs b/envoyfilter/src/filter_context.rs index be0f15d3..7cf1316d 100644 --- a/envoyfilter/src/filter_context.rs +++ b/envoyfilter/src/filter_context.rs @@ -253,6 +253,10 @@ impl RootContext for FilterContext { } fn create_http_context(&self, context_id: u32) -> Option> { + debug!( + "||| create_http_context called with context_id: {:?} |||", + context_id + ); Some(Box::new(StreamContext::new( context_id, Rc::clone(&self.metrics), diff --git a/envoyfilter/src/stream_context.rs b/envoyfilter/src/stream_context.rs index 371672a4..80e98fc6 100644 --- a/envoyfilter/src/stream_context.rs +++ b/envoyfilter/src/stream_context.rs @@ -282,10 +282,11 @@ impl StreamContext { + pred_class_desc_emb_similarity * 0.3; debug!( - "similarity score: {:.3}, intent score: {:.3}, description embedding score: {:.3}", + "similarity score: {:.3}, intent score: {:.3}, description embedding score: {:.3}, prompt: {}", prompt_target_similarity_score, zeroshot_intent_response.predicted_class_score, - pred_class_desc_emb_similarity + pred_class_desc_emb_similarity, + callout_context.user_message.as_ref().unwrap() ); let prompt_target_name = zeroshot_intent_response.predicted_class.clone(); @@ -467,51 +468,7 @@ impl StreamContext { } }; - // prompt target - - let prompt_target = self - .prompt_targets - .read() - .unwrap() - .get(callout_context.prompt_target_name.as_ref().unwrap()) - .unwrap() - .clone(); - - // // verify required parameters are present - // prompt_target - // .parameters - // .as_ref() - // .unwrap() - // .iter() - // .for_each(|param| match param.required { - // None => {} - // Some(required) => { - // if required - // && !tools_call_response.tool_calls[0] - // .arguments - // .contains_key(¶m.name) - // { - // self.send_server_error( - // format!( - // "missing required parameter: {}, for target: {}", - // param.name, prompt_target.name - // ), - // Some(StatusCode::BAD_REQUEST), - // ) - // } - // } - // }); - debug!("tool_call_details: {:?}", tools_call_response); - let tool_name = &tools_call_response.tool_calls[0].name; - - // ensure that detected tool name matches the prompt target name - if tool_name != &prompt_target.name { - warn!( - "tool name mismatch: detected tool name: {}, expected tool name: {}", - tool_name, &prompt_target.name - ); - } // extract all tool names let tool_names: Vec = tools_call_response .tool_calls @@ -519,12 +476,26 @@ impl StreamContext { .map(|tool_call| tool_call.name.clone()) .collect(); + debug!( + "call context similarity score: {:?}", + callout_context.similarity_scores + ); + //HACK: for now we only support one tool call, we will support multiple tool calls in the future let tool_params = &tools_call_response.tool_calls[0].arguments; + let tools_call_name = tools_call_response.tool_calls[0].name.clone(); let tool_params_json_str = serde_json::to_string(&tool_params).unwrap(); + let prompt_target = self + .prompt_targets + .read() + .unwrap() + .get(&tools_call_name) + .unwrap() + .clone(); + + debug!("prompt_target_name: {}", prompt_target.name); debug!("tool_name(s): {:?}", tool_names); debug!("tool_params: {}", tool_params_json_str); - debug!("prompt_target_name: {}", prompt_target.name); let endpoint = prompt_target.endpoint.as_ref().unwrap(); let token_id = match self.dispatch_http_call( @@ -554,6 +525,19 @@ impl StreamContext { } fn function_call_response_handler(&mut self, body: Vec, callout_context: CallContext) { + let headers = self.get_http_call_response_headers(); + debug!("response headers: {:?}", headers); + if let Some(http_status) = headers.iter().find(|(key, _)| key == ":status") { + if http_status.1 != StatusCode::OK.as_str() { + let error_msg = format!( + "Error in function call response: status code: {}", + http_status.1 + ); + return self.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST)); + } + } else { + warn!("http status code not found in api response"); + } debug!("response received for function call response"); let body_str: String = String::from_utf8(body).unwrap(); debug!("function_call_response response str: {:?}", body_str); diff --git a/envoyfilter/tests/integration.rs b/envoyfilter/tests/integration.rs index 32493df1..6a418c56 100644 --- a/envoyfilter/tests/integration.rs +++ b/envoyfilter/tests/integration.rs @@ -29,6 +29,7 @@ fn wasm_module() -> String { fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) { module .call_proxy_on_context_create(http_context, filter_context) + .expect_log(Some(LogLevel::Debug), None) .execute_and_expect(ReturnType::None) .unwrap(); @@ -231,6 +232,7 @@ fn successful_request_to_open_ai_chat_completions() { module .call_proxy_on_context_create(http_context, root_context) + .expect_log(Some(LogLevel::Debug), None) .execute_and_expect(ReturnType::None) .unwrap(); @@ -318,6 +320,7 @@ fn bad_request_to_open_ai_chat_completions() { module .call_proxy_on_context_create(http_context, root_context) + .expect_log(Some(LogLevel::Debug), None) .execute_and_expect(ReturnType::None) .unwrap(); @@ -453,6 +456,7 @@ fn request_ratelimited() { .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), None) .expect_http_call(Some("weatherhost"), None, None, None, None) .returning(Some(4)) .expect_metric_increment("active_http_calls", 1) @@ -466,6 +470,9 @@ fn request_ratelimited() { .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) .returning(Some(&body_text)) .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Warn), None) + .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) @@ -566,6 +573,7 @@ fn request_not_ratelimited() { .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), None) .expect_http_call(Some("weatherhost"), None, None, None, None) .returning(Some(4)) .expect_metric_increment("active_http_calls", 1) @@ -579,12 +587,14 @@ fn request_not_ratelimited() { .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) .returning(Some(&body_text)) .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Warn), None) + .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None) - // .expect_log(Some(LogLevel::Debug), None) .execute_and_expect(ReturnType::None) .unwrap(); } diff --git a/model_server/app/main.py b/model_server/app/main.py index e5988fef..4eea7b8a 100644 --- a/model_server/app/main.py +++ b/model_server/app/main.py @@ -62,7 +62,7 @@ if "prompt_guards" in config.keys(): toxic_model = None -guard_handler = GuardHandler(toxic_model, jailbreak_model) + guard_handler = GuardHandler(toxic_model, jailbreak_model) app = FastAPI()