mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
improve response handling (#71)
This commit is contained in:
parent
79b1c5415f
commit
eff4cd9826
7 changed files with 51 additions and 51 deletions
|
|
@ -33,8 +33,8 @@ def predict(message, history):
|
|||
# remove last user message in case of exception
|
||||
history.pop()
|
||||
log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
|
||||
log.info("Error with OpenAI API: {}".format(e.message))
|
||||
raise gr.Error("Error with OpenAI API: {}".format(e.message))
|
||||
log.info("Error calling gateway API: {}".format(e.message))
|
||||
raise gr.Error("Error calling gateway API: {}".format(e.message))
|
||||
|
||||
# for chunk in response:
|
||||
# if chunk.choices[0].delta.content is not None:
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ prompt_targets:
|
|||
- name: policy_number
|
||||
required: true
|
||||
description: The policy number for which the insurance claim details are requested.
|
||||
type: string
|
||||
- name: include_expired
|
||||
description: Include expired insurance claims in the response.
|
||||
type: string
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ services:
|
|||
retries: 20
|
||||
volumes:
|
||||
- ~/.cache/huggingface:/root/.cache/huggingface
|
||||
- ./bolt_config.yaml:/root/bolt_config.yaml
|
||||
|
||||
function_resolver:
|
||||
build:
|
||||
|
|
|
|||
|
|
@ -253,6 +253,10 @@ impl RootContext for FilterContext {
|
|||
}
|
||||
|
||||
fn create_http_context(&self, context_id: u32) -> Option<Box<dyn HttpContext>> {
|
||||
debug!(
|
||||
"||| create_http_context called with context_id: {:?} |||",
|
||||
context_id
|
||||
);
|
||||
Some(Box::new(StreamContext::new(
|
||||
context_id,
|
||||
Rc::clone(&self.metrics),
|
||||
|
|
|
|||
|
|
@ -282,10 +282,11 @@ impl StreamContext {
|
|||
+ pred_class_desc_emb_similarity * 0.3;
|
||||
|
||||
debug!(
|
||||
"similarity score: {:.3}, intent score: {:.3}, description embedding score: {:.3}",
|
||||
"similarity score: {:.3}, intent score: {:.3}, description embedding score: {:.3}, prompt: {}",
|
||||
prompt_target_similarity_score,
|
||||
zeroshot_intent_response.predicted_class_score,
|
||||
pred_class_desc_emb_similarity
|
||||
pred_class_desc_emb_similarity,
|
||||
callout_context.user_message.as_ref().unwrap()
|
||||
);
|
||||
|
||||
let prompt_target_name = zeroshot_intent_response.predicted_class.clone();
|
||||
|
|
@ -467,51 +468,7 @@ impl StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
// prompt target
|
||||
|
||||
let prompt_target = self
|
||||
.prompt_targets
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(callout_context.prompt_target_name.as_ref().unwrap())
|
||||
.unwrap()
|
||||
.clone();
|
||||
|
||||
// // verify required parameters are present
|
||||
// prompt_target
|
||||
// .parameters
|
||||
// .as_ref()
|
||||
// .unwrap()
|
||||
// .iter()
|
||||
// .for_each(|param| match param.required {
|
||||
// None => {}
|
||||
// Some(required) => {
|
||||
// if required
|
||||
// && !tools_call_response.tool_calls[0]
|
||||
// .arguments
|
||||
// .contains_key(&param.name)
|
||||
// {
|
||||
// self.send_server_error(
|
||||
// format!(
|
||||
// "missing required parameter: {}, for target: {}",
|
||||
// param.name, prompt_target.name
|
||||
// ),
|
||||
// Some(StatusCode::BAD_REQUEST),
|
||||
// )
|
||||
// }
|
||||
// }
|
||||
// });
|
||||
|
||||
debug!("tool_call_details: {:?}", tools_call_response);
|
||||
let tool_name = &tools_call_response.tool_calls[0].name;
|
||||
|
||||
// ensure that detected tool name matches the prompt target name
|
||||
if tool_name != &prompt_target.name {
|
||||
warn!(
|
||||
"tool name mismatch: detected tool name: {}, expected tool name: {}",
|
||||
tool_name, &prompt_target.name
|
||||
);
|
||||
}
|
||||
// extract all tool names
|
||||
let tool_names: Vec<String> = tools_call_response
|
||||
.tool_calls
|
||||
|
|
@ -519,12 +476,26 @@ impl StreamContext {
|
|||
.map(|tool_call| tool_call.name.clone())
|
||||
.collect();
|
||||
|
||||
debug!(
|
||||
"call context similarity score: {:?}",
|
||||
callout_context.similarity_scores
|
||||
);
|
||||
//HACK: for now we only support one tool call, we will support multiple tool calls in the future
|
||||
let tool_params = &tools_call_response.tool_calls[0].arguments;
|
||||
let tools_call_name = tools_call_response.tool_calls[0].name.clone();
|
||||
let tool_params_json_str = serde_json::to_string(&tool_params).unwrap();
|
||||
|
||||
let prompt_target = self
|
||||
.prompt_targets
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(&tools_call_name)
|
||||
.unwrap()
|
||||
.clone();
|
||||
|
||||
debug!("prompt_target_name: {}", prompt_target.name);
|
||||
debug!("tool_name(s): {:?}", tool_names);
|
||||
debug!("tool_params: {}", tool_params_json_str);
|
||||
debug!("prompt_target_name: {}", prompt_target.name);
|
||||
|
||||
let endpoint = prompt_target.endpoint.as_ref().unwrap();
|
||||
let token_id = match self.dispatch_http_call(
|
||||
|
|
@ -554,6 +525,19 @@ impl StreamContext {
|
|||
}
|
||||
|
||||
fn function_call_response_handler(&mut self, body: Vec<u8>, callout_context: CallContext) {
|
||||
let headers = self.get_http_call_response_headers();
|
||||
debug!("response headers: {:?}", headers);
|
||||
if let Some(http_status) = headers.iter().find(|(key, _)| key == ":status") {
|
||||
if http_status.1 != StatusCode::OK.as_str() {
|
||||
let error_msg = format!(
|
||||
"Error in function call response: status code: {}",
|
||||
http_status.1
|
||||
);
|
||||
return self.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
|
||||
}
|
||||
} else {
|
||||
warn!("http status code not found in api response");
|
||||
}
|
||||
debug!("response received for function call response");
|
||||
let body_str: String = String::from_utf8(body).unwrap();
|
||||
debug!("function_call_response response str: {:?}", body_str);
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ fn wasm_module() -> String {
|
|||
fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
|
||||
module
|
||||
.call_proxy_on_context_create(http_context, filter_context)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -231,6 +232,7 @@ fn successful_request_to_open_ai_chat_completions() {
|
|||
|
||||
module
|
||||
.call_proxy_on_context_create(http_context, root_context)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -318,6 +320,7 @@ fn bad_request_to_open_ai_chat_completions() {
|
|||
|
||||
module
|
||||
.call_proxy_on_context_create(http_context, root_context)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -453,6 +456,7 @@ fn request_ratelimited() {
|
|||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(Some("weatherhost"), None, None, None, None)
|
||||
.returning(Some(4))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
|
|
@ -466,6 +470,9 @@ fn request_ratelimited() {
|
|||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&body_text))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Warn), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
|
|
@ -566,6 +573,7 @@ fn request_not_ratelimited() {
|
|||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(Some("weatherhost"), None, None, None, None)
|
||||
.returning(Some(4))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
|
|
@ -579,12 +587,14 @@ fn request_not_ratelimited() {
|
|||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&body_text))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Warn), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||
// .expect_log(Some(LogLevel::Debug), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ if "prompt_guards" in config.keys():
|
|||
toxic_model = None
|
||||
|
||||
|
||||
guard_handler = GuardHandler(toxic_model, jailbreak_model)
|
||||
guard_handler = GuardHandler(toxic_model, jailbreak_model)
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue