send all tools when sending request to arch-fc (#59)

This commit is contained in:
Adil Hafeez 2024-09-18 15:54:40 -07:00 committed by GitHub
parent 3135ba8eae
commit 215d276acf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 93 additions and 46 deletions

View file

@ -41,7 +41,7 @@ enum ResponseHandlerType {
pub struct CallContext {
response_handler_type: ResponseHandlerType,
user_message: Option<String>,
prompt_target: Option<PromptTarget>,
prompt_target_name: Option<String>,
request_body: ChatCompletionsRequest,
similarity_scores: Option<Vec<(String, f64)>>,
}
@ -325,46 +325,47 @@ impl StreamContext {
.unwrap()
.clone();
info!(
"prompt_target name: {:?}, type: {:?}",
prompt_target.name, prompt_target.prompt_type
);
info!("prompt_target name: {:?}", prompt_target_name);
match prompt_target.prompt_type {
PromptType::FunctionResolver => {
// only extract entity names
let properties: HashMap<String, ToolParameter> = match prompt_target.parameters {
// Clone is unavoidable here because we don't want to move the values out of the prompt target struct.
Some(ref entities) => {
let mut properties: HashMap<String, ToolParameter> = HashMap::new();
for entity in entities.iter() {
let param = ToolParameter {
parameter_type: entity.parameter_type.clone(),
description: entity.description.clone(),
required: entity.required,
enum_values: entity.enum_values.clone(),
};
properties.insert(entity.name.clone(), param);
}
properties
}
None => HashMap::new(),
};
let tools_parameters = ToolParameters {
parameters_type: "dict".to_string(),
properties,
};
let mut tools_definitions: Vec<ToolsDefinition> = Vec::new();
let tools_defintion: ToolsDefinition = ToolsDefinition {
name: prompt_target.name.clone(),
description: prompt_target.description.clone(),
parameters: tools_parameters,
};
for pt in self.prompt_targets.read().unwrap().values() {
// only extract entity names
let properties: HashMap<String, ToolParameter> = match pt.parameters {
// Clone is unavoidable here because we don't want to move the values out of the prompt target struct.
Some(ref entities) => {
let mut properties: HashMap<String, ToolParameter> = HashMap::new();
for entity in entities.iter() {
let param = ToolParameter {
parameter_type: entity.parameter_type.clone(),
description: entity.description.clone(),
required: entity.required,
enum_values: entity.enum_values.clone(),
};
properties.insert(entity.name.clone(), param);
}
properties
}
None => HashMap::new(),
};
let tools_parameters = ToolParameters {
parameters_type: "dict".to_string(),
properties,
};
tools_definitions.push(ToolsDefinition {
name: pt.name.clone(),
description: pt.description.clone(),
parameters: tools_parameters,
});
}
let chat_completions = ChatCompletionsRequest {
model: GPT_35_TURBO.to_string(),
messages: callout_context.request_body.messages.clone(),
tools: Some(vec![tools_defintion]),
tools: Some(tools_definitions),
stream: false,
stream_options: None,
};
@ -411,7 +412,7 @@ impl StreamContext {
);
callout_context.response_handler_type = ResponseHandlerType::FunctionResolver;
callout_context.prompt_target = Some(prompt_target);
callout_context.prompt_target_name = Some(prompt_target.name);
if self.callouts.insert(token_id, callout_context).is_some() {
panic!("duplicate token_id")
}
@ -438,7 +439,7 @@ impl StreamContext {
// Let's send the response back to the user to initalize lightweight dialog for parameter collection
// add resolver name to the response so the client can send the response back to the correct resolver
boltfc_response.resolver_name = Some(callout_context.prompt_target.unwrap().name);
boltfc_response.resolver_name = Some(callout_context.prompt_target_name.unwrap());
info!("some requred parameters are missing, sending response from Bolt FC back to user for parameter collection: {}", e);
let bolt_fc_dialogue_message = serde_json::to_string(&boltfc_response).unwrap();
self.send_http_response(
@ -450,11 +451,18 @@ impl StreamContext {
}
};
// verify required parameters are present
callout_context
.prompt_target
.as_ref()
// prompt target
let prompt_target = self
.prompt_targets
.read()
.unwrap()
.get(callout_context.prompt_target_name.as_ref().unwrap())
.unwrap()
.clone();
// verify required parameters are present
prompt_target
.parameters
.as_ref()
.unwrap()
@ -477,10 +485,17 @@ impl StreamContext {
debug!("tool_call_details: {:?}", tools_call_response);
let tool_name = &tools_call_response.tool_calls[0].name;
// ensure that detected tool name matches the prompt target name
if tool_name != &prompt_target.name {
warn!(
"tool name mismatch: detected tool name: {}, expected tool name: {}",
tool_name, &prompt_target.name
);
}
let tool_params = &tools_call_response.tool_calls[0].arguments;
debug!("tool_name: {:?}", tool_name);
debug!("tool_params: {:?}", tool_params);
let prompt_target = callout_context.prompt_target.as_ref().unwrap();
debug!("prompt_target: {:?}", prompt_target);
let tool_params_json_str = serde_json::to_string(&tool_params).unwrap();
@ -516,7 +531,14 @@ impl StreamContext {
debug!("response received for function call response");
let body_str: String = String::from_utf8(body).unwrap();
debug!("function_call_response response str: {:?}", body_str);
let prompt_target = callout_context.prompt_target.as_ref().unwrap();
let prompt_target_name = callout_context.prompt_target_name.unwrap();
let prompt_target = self
.prompt_targets
.read()
.unwrap()
.get(&prompt_target_name)
.unwrap()
.clone();
let mut messages: Vec<Message> = callout_context.request_body.messages.clone();
@ -714,7 +736,7 @@ impl HttpContext for StreamContext {
let call_context = CallContext {
response_handler_type: ResponseHandlerType::GetEmbeddings,
user_message: Some(user_message),
prompt_target: None,
prompt_target_name: None,
request_body: deserialized_body,
similarity_scores: None,
};