mirror of
https://github.com/katanemo/plano.git
synced 2026-05-18 13:45:15 +02:00
send all tools when sending request to arch-fc (#59)
This commit is contained in:
parent
3135ba8eae
commit
215d276acf
3 changed files with 93 additions and 46 deletions
|
|
@ -41,7 +41,7 @@ enum ResponseHandlerType {
|
|||
pub struct CallContext {
|
||||
response_handler_type: ResponseHandlerType,
|
||||
user_message: Option<String>,
|
||||
prompt_target: Option<PromptTarget>,
|
||||
prompt_target_name: Option<String>,
|
||||
request_body: ChatCompletionsRequest,
|
||||
similarity_scores: Option<Vec<(String, f64)>>,
|
||||
}
|
||||
|
|
@ -325,46 +325,47 @@ impl StreamContext {
|
|||
.unwrap()
|
||||
.clone();
|
||||
|
||||
info!(
|
||||
"prompt_target name: {:?}, type: {:?}",
|
||||
prompt_target.name, prompt_target.prompt_type
|
||||
);
|
||||
info!("prompt_target name: {:?}", prompt_target_name);
|
||||
|
||||
match prompt_target.prompt_type {
|
||||
PromptType::FunctionResolver => {
|
||||
// only extract entity names
|
||||
let properties: HashMap<String, ToolParameter> = match prompt_target.parameters {
|
||||
// Clone is unavoidable here because we don't want to move the values out of the prompt target struct.
|
||||
Some(ref entities) => {
|
||||
let mut properties: HashMap<String, ToolParameter> = HashMap::new();
|
||||
for entity in entities.iter() {
|
||||
let param = ToolParameter {
|
||||
parameter_type: entity.parameter_type.clone(),
|
||||
description: entity.description.clone(),
|
||||
required: entity.required,
|
||||
enum_values: entity.enum_values.clone(),
|
||||
};
|
||||
properties.insert(entity.name.clone(), param);
|
||||
}
|
||||
properties
|
||||
}
|
||||
None => HashMap::new(),
|
||||
};
|
||||
let tools_parameters = ToolParameters {
|
||||
parameters_type: "dict".to_string(),
|
||||
properties,
|
||||
};
|
||||
let mut tools_definitions: Vec<ToolsDefinition> = Vec::new();
|
||||
|
||||
let tools_defintion: ToolsDefinition = ToolsDefinition {
|
||||
name: prompt_target.name.clone(),
|
||||
description: prompt_target.description.clone(),
|
||||
parameters: tools_parameters,
|
||||
};
|
||||
for pt in self.prompt_targets.read().unwrap().values() {
|
||||
// only extract entity names
|
||||
let properties: HashMap<String, ToolParameter> = match pt.parameters {
|
||||
// Clone is unavoidable here because we don't want to move the values out of the prompt target struct.
|
||||
Some(ref entities) => {
|
||||
let mut properties: HashMap<String, ToolParameter> = HashMap::new();
|
||||
for entity in entities.iter() {
|
||||
let param = ToolParameter {
|
||||
parameter_type: entity.parameter_type.clone(),
|
||||
description: entity.description.clone(),
|
||||
required: entity.required,
|
||||
enum_values: entity.enum_values.clone(),
|
||||
};
|
||||
properties.insert(entity.name.clone(), param);
|
||||
}
|
||||
properties
|
||||
}
|
||||
None => HashMap::new(),
|
||||
};
|
||||
let tools_parameters = ToolParameters {
|
||||
parameters_type: "dict".to_string(),
|
||||
properties,
|
||||
};
|
||||
|
||||
tools_definitions.push(ToolsDefinition {
|
||||
name: pt.name.clone(),
|
||||
description: pt.description.clone(),
|
||||
parameters: tools_parameters,
|
||||
});
|
||||
}
|
||||
|
||||
let chat_completions = ChatCompletionsRequest {
|
||||
model: GPT_35_TURBO.to_string(),
|
||||
messages: callout_context.request_body.messages.clone(),
|
||||
tools: Some(vec![tools_defintion]),
|
||||
tools: Some(tools_definitions),
|
||||
stream: false,
|
||||
stream_options: None,
|
||||
};
|
||||
|
|
@ -411,7 +412,7 @@ impl StreamContext {
|
|||
);
|
||||
|
||||
callout_context.response_handler_type = ResponseHandlerType::FunctionResolver;
|
||||
callout_context.prompt_target = Some(prompt_target);
|
||||
callout_context.prompt_target_name = Some(prompt_target.name);
|
||||
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
|
|
@ -438,7 +439,7 @@ impl StreamContext {
|
|||
// Let's send the response back to the user to initalize lightweight dialog for parameter collection
|
||||
|
||||
// add resolver name to the response so the client can send the response back to the correct resolver
|
||||
boltfc_response.resolver_name = Some(callout_context.prompt_target.unwrap().name);
|
||||
boltfc_response.resolver_name = Some(callout_context.prompt_target_name.unwrap());
|
||||
info!("some requred parameters are missing, sending response from Bolt FC back to user for parameter collection: {}", e);
|
||||
let bolt_fc_dialogue_message = serde_json::to_string(&boltfc_response).unwrap();
|
||||
self.send_http_response(
|
||||
|
|
@ -450,11 +451,18 @@ impl StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
// verify required parameters are present
|
||||
callout_context
|
||||
.prompt_target
|
||||
.as_ref()
|
||||
// prompt target
|
||||
|
||||
let prompt_target = self
|
||||
.prompt_targets
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(callout_context.prompt_target_name.as_ref().unwrap())
|
||||
.unwrap()
|
||||
.clone();
|
||||
// verify required parameters are present
|
||||
|
||||
prompt_target
|
||||
.parameters
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
|
|
@ -477,10 +485,17 @@ impl StreamContext {
|
|||
|
||||
debug!("tool_call_details: {:?}", tools_call_response);
|
||||
let tool_name = &tools_call_response.tool_calls[0].name;
|
||||
|
||||
// ensure that detected tool name matches the prompt target name
|
||||
if tool_name != &prompt_target.name {
|
||||
warn!(
|
||||
"tool name mismatch: detected tool name: {}, expected tool name: {}",
|
||||
tool_name, &prompt_target.name
|
||||
);
|
||||
}
|
||||
let tool_params = &tools_call_response.tool_calls[0].arguments;
|
||||
debug!("tool_name: {:?}", tool_name);
|
||||
debug!("tool_params: {:?}", tool_params);
|
||||
let prompt_target = callout_context.prompt_target.as_ref().unwrap();
|
||||
debug!("prompt_target: {:?}", prompt_target);
|
||||
|
||||
let tool_params_json_str = serde_json::to_string(&tool_params).unwrap();
|
||||
|
|
@ -516,7 +531,14 @@ impl StreamContext {
|
|||
debug!("response received for function call response");
|
||||
let body_str: String = String::from_utf8(body).unwrap();
|
||||
debug!("function_call_response response str: {:?}", body_str);
|
||||
let prompt_target = callout_context.prompt_target.as_ref().unwrap();
|
||||
let prompt_target_name = callout_context.prompt_target_name.unwrap();
|
||||
let prompt_target = self
|
||||
.prompt_targets
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(&prompt_target_name)
|
||||
.unwrap()
|
||||
.clone();
|
||||
|
||||
let mut messages: Vec<Message> = callout_context.request_body.messages.clone();
|
||||
|
||||
|
|
@ -714,7 +736,7 @@ impl HttpContext for StreamContext {
|
|||
let call_context = CallContext {
|
||||
response_handler_type: ResponseHandlerType::GetEmbeddings,
|
||||
user_message: Some(user_message),
|
||||
prompt_target: None,
|
||||
prompt_target_name: None,
|
||||
request_body: deserialized_body,
|
||||
similarity_scores: None,
|
||||
};
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue