improve response handling (#71)

This commit is contained in:
Adil Hafeez 2024-09-23 22:56:35 -07:00 committed by GitHub
parent 79b1c5415f
commit eff4cd9826
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 51 additions and 51 deletions

View file

@ -253,6 +253,10 @@ impl RootContext for FilterContext {
}
fn create_http_context(&self, context_id: u32) -> Option<Box<dyn HttpContext>> {
debug!(
"||| create_http_context called with context_id: {:?} |||",
context_id
);
Some(Box::new(StreamContext::new(
context_id,
Rc::clone(&self.metrics),

View file

@ -282,10 +282,11 @@ impl StreamContext {
+ pred_class_desc_emb_similarity * 0.3;
debug!(
"similarity score: {:.3}, intent score: {:.3}, description embedding score: {:.3}",
"similarity score: {:.3}, intent score: {:.3}, description embedding score: {:.3}, prompt: {}",
prompt_target_similarity_score,
zeroshot_intent_response.predicted_class_score,
pred_class_desc_emb_similarity
pred_class_desc_emb_similarity,
callout_context.user_message.as_ref().unwrap()
);
let prompt_target_name = zeroshot_intent_response.predicted_class.clone();
@ -467,51 +468,7 @@ impl StreamContext {
}
};
// prompt target
let prompt_target = self
.prompt_targets
.read()
.unwrap()
.get(callout_context.prompt_target_name.as_ref().unwrap())
.unwrap()
.clone();
// // verify required parameters are present
// prompt_target
// .parameters
// .as_ref()
// .unwrap()
// .iter()
// .for_each(|param| match param.required {
// None => {}
// Some(required) => {
// if required
// && !tools_call_response.tool_calls[0]
// .arguments
// .contains_key(&param.name)
// {
// self.send_server_error(
// format!(
// "missing required parameter: {}, for target: {}",
// param.name, prompt_target.name
// ),
// Some(StatusCode::BAD_REQUEST),
// )
// }
// }
// });
debug!("tool_call_details: {:?}", tools_call_response);
let tool_name = &tools_call_response.tool_calls[0].name;
// ensure that detected tool name matches the prompt target name
if tool_name != &prompt_target.name {
warn!(
"tool name mismatch: detected tool name: {}, expected tool name: {}",
tool_name, &prompt_target.name
);
}
// extract all tool names
let tool_names: Vec<String> = tools_call_response
.tool_calls
@ -519,12 +476,26 @@ impl StreamContext {
.map(|tool_call| tool_call.name.clone())
.collect();
debug!(
"call context similarity score: {:?}",
callout_context.similarity_scores
);
//HACK: for now we only support one tool call, we will support multiple tool calls in the future
let tool_params = &tools_call_response.tool_calls[0].arguments;
let tools_call_name = tools_call_response.tool_calls[0].name.clone();
let tool_params_json_str = serde_json::to_string(&tool_params).unwrap();
let prompt_target = self
.prompt_targets
.read()
.unwrap()
.get(&tools_call_name)
.unwrap()
.clone();
debug!("prompt_target_name: {}", prompt_target.name);
debug!("tool_name(s): {:?}", tool_names);
debug!("tool_params: {}", tool_params_json_str);
debug!("prompt_target_name: {}", prompt_target.name);
let endpoint = prompt_target.endpoint.as_ref().unwrap();
let token_id = match self.dispatch_http_call(
@ -554,6 +525,19 @@ impl StreamContext {
}
fn function_call_response_handler(&mut self, body: Vec<u8>, callout_context: CallContext) {
let headers = self.get_http_call_response_headers();
debug!("response headers: {:?}", headers);
if let Some(http_status) = headers.iter().find(|(key, _)| key == ":status") {
if http_status.1 != StatusCode::OK.as_str() {
let error_msg = format!(
"Error in function call response: status code: {}",
http_status.1
);
return self.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
}
} else {
warn!("http status code not found in api response");
}
debug!("response received for function call response");
let body_str: String = String::from_utf8(body).unwrap();
debug!("function_call_response response str: {:?}", body_str);