mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
fixed JSON parsing issues in function_calling.rs
This commit is contained in:
parent
9922fe0cb9
commit
e7918d0caf
12 changed files with 38 additions and 953 deletions
|
|
@ -98,17 +98,10 @@ pub struct ArchFunctionConfig {
|
|||
impl Default for ArchFunctionConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
task_prompt: concat!(
|
||||
"You are a helpful assistant designed to assist with the user query by making one or more function calls if needed.",
|
||||
"\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{tools}\n</tools>",
|
||||
"\n\nYour task is to decide which functions are needed and collect missing parameters if necessary."
|
||||
).to_string(),
|
||||
format_prompt: concat!(
|
||||
"\n\nBased on your analysis, provide your response in one of the following JSON formats:",
|
||||
"\n1. If no functions are needed:\n```json\n{\"response\": \"Your response text here\"}\n```",
|
||||
"\n2. If functions are needed but some required parameters are missing:\n```json\n{\"required_functions\": [\"func_name1\", \"func_name2\", ...], \"clarification\": \"Text asking for missing parameters\"}\n```",
|
||||
"\n3. If functions are needed and all required parameters are available:\n```json\n{\"tool_calls\": [{\"name\": \"func_name1\", \"arguments\": {\"argument1\": \"value1\", \"argument2\": \"value2\"}},... (more tool calls as required)]}\n```"
|
||||
).to_string(),
|
||||
// Raw string so that \n sequences remain literal in the final prompt
|
||||
task_prompt: r#"You are a helpful assistant designed to assist with the user query by making one or more function calls if needed.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{tools}\n</tools>\n\nYour task is to decide which functions are needed and collect missing parameters if necessary."#.to_string(),
|
||||
// Use raw string to preserve literal \n sequences instead of real newlines
|
||||
format_prompt: r#"\n\nBased on your analysis, provide your response in one of the following JSON formats:\n1. If no functions are needed:\n```json\n{\"response\": \"Your response text here\"}\n```\n2. If functions are needed but some required parameters are missing:\n```json\n{\"required_functions\": [\"func_name1\", \"func_name2\", ...], \"clarification\": \"Text asking for missing parameters\"}\n```\n3. If functions are needed and all required parameters are available:\n```json\n{\"tool_calls\": [{\"name\": \"func_name1\", \"arguments\": {\"argument1\": \"value1\", \"argument2\": \"value2\"}},... (more tool calls as required)]}\n```"#.to_string(),
|
||||
generation_params: GenerationParams::default(),
|
||||
support_data_types: vec![
|
||||
"int".to_string(),
|
||||
|
|
@ -255,8 +248,8 @@ impl ArchFunctionHandler {
|
|||
Self {
|
||||
model_name,
|
||||
config,
|
||||
default_prefix: "```json\n{\"".to_string(),
|
||||
clarify_prefix: "```json\n{\"required_functions\":".to_string(),
|
||||
default_prefix: r#"```json\n{\""#.to_string(),
|
||||
clarify_prefix: r#"```json\n{\"required_functions\":"#.to_string(),
|
||||
endpoint_url,
|
||||
http_client,
|
||||
}
|
||||
|
|
@ -270,7 +263,7 @@ impl ArchFunctionHandler {
|
|||
.collect();
|
||||
|
||||
converted
|
||||
.map(|v| v.join("\n"))
|
||||
.map(|v| v.join("\\n"))
|
||||
.map_err(FunctionCallingError::from)
|
||||
}
|
||||
|
||||
|
|
@ -336,19 +329,37 @@ impl ArchFunctionHandler {
|
|||
pub fn parse_model_response(&self, content: &str) -> ParsedModelResponse {
|
||||
let mut response_dict = ParsedModelResponse::default();
|
||||
|
||||
// Store original content for raw_response before any processing
|
||||
let original_content = content.trim().to_string();
|
||||
|
||||
// Remove markdown code blocks
|
||||
let mut content = content.trim().to_string();
|
||||
if content.starts_with("```") && content.ends_with("```") {
|
||||
content = content.trim_start_matches("```").trim_end_matches("```").trim().to_string();
|
||||
content = content.trim_start_matches("```").trim_end_matches("```").to_string();
|
||||
if content.starts_with("json") {
|
||||
content = content.trim_start_matches("json").trim().to_string();
|
||||
content = content.trim_start_matches("json").to_string();
|
||||
}
|
||||
// Trim again after removing code blocks to eliminate internal whitespace
|
||||
content = content.trim_start_matches(r"\n").trim_end_matches(r"\n").to_string();
|
||||
content = content.trim().to_string();
|
||||
// Unescape the quotes: \" -> "
|
||||
// The model sometimes returns escaped JSON inside markdown blocks
|
||||
content = content.replace(r#"\""#, "\"");
|
||||
}
|
||||
|
||||
// Try to fix JSON if needed
|
||||
let fixed_content = match self.fix_json_string(&content) {
|
||||
Ok(fixed) => {
|
||||
response_dict.raw_response = format!("```json\n{}\n```", fixed);
|
||||
// Build raw_response with literal \n sequences (not actual newlines)
|
||||
// Replace actual newlines with \n string literal
|
||||
let sanitized_content = if original_content.starts_with("```") {
|
||||
// Original had code blocks - replace actual newlines with \n literals
|
||||
original_content.replace('\n', r"\n")
|
||||
} else {
|
||||
// Wrap the content with literal \n sequences
|
||||
format!(r"```json\n{}\n```", content)
|
||||
};
|
||||
response_dict.raw_response = sanitized_content;
|
||||
fixed
|
||||
}
|
||||
Err(e) => {
|
||||
|
|
@ -357,12 +368,12 @@ impl ArchFunctionHandler {
|
|||
return response_dict;
|
||||
}
|
||||
};
|
||||
|
||||
// Parse the JSON
|
||||
match serde_json::from_str::<Value>(&fixed_content) {
|
||||
Ok(model_response) => {
|
||||
// Successfully parsed - mark as valid
|
||||
response_dict.is_valid = true;
|
||||
info!("c8: {:?}", model_response);
|
||||
|
||||
// Extract response field
|
||||
if let Some(resp) = model_response.get("response") {
|
||||
|
|
@ -600,12 +611,9 @@ impl ArchFunctionHandler {
|
|||
/// Formats the system prompt with tools
|
||||
pub fn format_system_prompt(&self, tools: &[Tool]) -> Result<String> {
|
||||
let tools_str = self.convert_tools(tools)?;
|
||||
let today_date = chrono::Local::now().format("%Y-%m-%d").to_string();
|
||||
|
||||
let system_prompt = self
|
||||
.config
|
||||
.task_prompt
|
||||
.replace("{today_date}", &today_date)
|
||||
.replace("{tools}", &tools_str)
|
||||
+ &self.config.format_prompt;
|
||||
|
||||
|
|
@ -1369,15 +1377,15 @@ mod tests {
|
|||
assert_eq!(config.generation_params.temperature, 0.1);
|
||||
assert_eq!(config.support_data_types.len(), 14); // 8 Python-style + 6 JSON Schema names
|
||||
|
||||
// Verify exact prompt formatting matches Python
|
||||
// Task prompt should have actual newlines, not escaped strings
|
||||
assert!(config.task_prompt.contains("\n\nYou are provided"));
|
||||
assert!(config.task_prompt.contains("</tools>\n\n"));
|
||||
// Verify prompt formatting for literal escaped newlines ("\\n") instead of actual newline chars
|
||||
// The user requirement changed prompts to display "\\n" sequences literally.
|
||||
assert!(config.task_prompt.contains("\\n\\nYou are provided"));
|
||||
assert!(config.task_prompt.contains("</tools>\\n\\n"));
|
||||
|
||||
// Format prompt should have actual newlines and proper JSON with escaped quotes
|
||||
assert!(config.format_prompt.contains("\n\nBased on your analysis"));
|
||||
assert!(config.format_prompt.contains(r#"{"response": "Your response text here"}"#));
|
||||
assert!(config.format_prompt.contains(r#"{"tool_calls": [{"#));
|
||||
// Format prompt should contain literal escaped newlines and proper JSON examples
|
||||
assert!(config.format_prompt.contains("\\n\\nBased on your analysis"));
|
||||
assert!(config.format_prompt.contains(r#"{\"response\": \"Your response text here\"}"#));
|
||||
assert!(config.format_prompt.contains(r#"{\"tool_calls\": [{"#));
|
||||
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue