use bytes::Bytes; use eventsource_stream::Eventsource; use futures::StreamExt; use hermesllm::apis::openai::{ ChatCompletionsRequest, ChatCompletionsResponse, Choice, FinishReason, FunctionCall, Message, MessageContent, ResponseMessage, Role, Tool, ToolCall, Usage, }; use http_body_util::{combinators::BoxBody, BodyExt, Full}; use hyper::body::Incoming; use hyper::{Request, Response, StatusCode}; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use std::collections::HashMap; use thiserror::Error; use tracing::{error, info}; // ============================================================================ // CONSTANTS FOR HALLUCINATION DETECTION // ============================================================================ const FUNC_NAME_START_PATTERN: &[&str] = &[r#"{"name":""#, r#"{'name':'"#]; const FUNC_NAME_END_TOKEN: &[&str] = &["\",", "',"]; const END_TOOL_CALL_TOKEN: &str = "}}"; const FIRST_PARAM_NAME_START_PATTERN: &[&str] = &[r#""arguments":{"#, r#"'arguments':{'"#]; const PARAMETER_NAME_END_TOKENS: &[&str] = &["\":", ":\"", "':", ":'", "\":\"", "':'"]; const PARAMETER_NAME_START_PATTERN: &[&str] = &["\",\"", "','"]; const PARAMETER_VALUE_START_PATTERN: &[&str] = &["\":", "':"]; const PARAMETER_VALUE_END_TOKEN: &[&str] = &["\",", "\"}"]; const ARCH_FUNCTION_MODEL_NAME: &str = "Arch-Function"; /// Default hallucination detection thresholds #[derive(Debug, Clone)] pub struct HallucinationThresholds { pub entropy: f64, pub varentropy: f64, pub probability: f64, } impl Default for HallucinationThresholds { fn default() -> Self { Self { entropy: 0.0001, varentropy: 0.0001, probability: 0.8, } } } // ============================================================================ // ERROR TYPES // ============================================================================ #[derive(Debug, Error)] pub enum FunctionCallingError { #[error("Failed to parse JSON: {0}")] JsonParseError(#[from] serde_json::Error), #[error("Failed to fix malformed JSON: {0}")] JsonFixError(String), #[error("Invalid model response: {0}")] InvalidModelResponse(String), #[error("Tool call verification failed: {0}")] ToolCallVerificationError(String), #[error("Data type conversion error: {0}")] DataTypeConversionError(String), #[error("Unsupported data type: {0}")] UnsupportedDataType(String), #[error("HTTP request error: {0}")] HttpError(#[from] reqwest::Error), #[error("Invalid tool call: {0}")] InvalidToolCall(String), } pub type Result = std::result::Result; // ============================================================================ // CONFIGURATION STRUCTURES // ============================================================================ /// Configuration for Arch Function Calling #[derive(Debug, Clone)] pub struct ArchFunctionConfig { pub task_prompt: String, pub format_prompt: String, pub generation_params: GenerationParams, pub support_data_types: Vec, } impl Default for ArchFunctionConfig { fn default() -> Self { Self { // Raw string so that \n sequences remain literal in the final prompt task_prompt: r#"You are a helpful assistant designed to assist with the user query by making one or more function calls if needed.\n\nYou are provided with function signatures within XML tags:\n\n{tools}\n\n\nYour task is to decide which functions are needed and collect missing parameters if necessary."#.to_string(), // Use raw string to preserve literal \n sequences instead of real newlines format_prompt: r#"\n\nBased on your analysis, provide your response in one of the following JSON formats:\n1. If no functions are needed:\n```json\n{\"response\": \"Your response text here\"}\n```\n2. If functions are needed but some required parameters are missing:\n```json\n{\"required_functions\": [\"func_name1\", \"func_name2\", ...], \"clarification\": \"Text asking for missing parameters\"}\n```\n3. If functions are needed and all required parameters are available:\n```json\n{\"tool_calls\": [{\"name\": \"func_name1\", \"arguments\": {\"argument1\": \"value1\", \"argument2\": \"value2\"}},... (more tool calls as required)]}\n```"#.to_string(), generation_params: GenerationParams::default(), support_data_types: vec![ "int".to_string(), "float".to_string(), "bool".to_string(), "str".to_string(), "list".to_string(), "tuple".to_string(), "set".to_string(), "dict".to_string(), // JSON Schema names (standard) "integer".to_string(), "number".to_string(), "boolean".to_string(), "string".to_string(), "array".to_string(), "object".to_string(), ], } } } /// Configuration for Arch Agent (extends ArchFunctionConfig with different generation params) #[derive(Debug, Clone)] pub struct ArchAgentConfig { pub task_prompt: String, pub format_prompt: String, pub generation_params: GenerationParams, pub support_data_types: Vec, } impl Default for ArchAgentConfig { fn default() -> Self { let base = ArchFunctionConfig::default(); Self { task_prompt: base.task_prompt, format_prompt: base.format_prompt, generation_params: GenerationParams { temperature: 0.01, top_p: 1.0, top_k: 10, max_tokens: 1024, stop_token_ids: vec![151645], logprobs: Some(true), top_logprobs: Some(10), }, support_data_types: base.support_data_types, } } } /// Generation parameters for LLM #[derive(Debug, Clone, Serialize, Deserialize)] pub struct GenerationParams { pub temperature: f32, pub top_p: f32, pub top_k: u32, pub max_tokens: u32, pub stop_token_ids: Vec, pub logprobs: Option, pub top_logprobs: Option, } impl Default for GenerationParams { fn default() -> Self { Self { temperature: 0.1, top_p: 1.0, top_k: 10, max_tokens: 1024, stop_token_ids: vec![151645], logprobs: Some(true), top_logprobs: Some(10), } } } // ============================================================================ // PARSED MODEL RESPONSE // ============================================================================ /// Parsed response from the model #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct ParsedModelResponse { pub raw_response: String, pub response: Option, pub required_functions: Vec, pub clarification: String, pub tool_calls: Vec, pub is_valid: bool, pub error_message: String, } // ============================================================================ // TOOL CALL VERIFICATION RESULT // ============================================================================ /// Result of tool call verification #[derive(Debug, Clone)] pub struct ToolCallVerification { pub is_valid: bool, pub invalid_tool_call: Option, pub error_message: String, } impl Default for ToolCallVerification { fn default() -> Self { Self { is_valid: true, invalid_tool_call: None, error_message: String::new(), } } } /// Main handler for Arch Function Calling pub struct ArchFunctionHandler { pub model_name: String, pub config: ArchFunctionConfig, pub default_prefix: String, pub clarify_prefix: String, pub endpoint_url: String, pub http_client: reqwest::Client, } impl ArchFunctionHandler { /// Creates a new ArchFunctionHandler pub fn new(model_name: String, config: ArchFunctionConfig, endpoint_url: String) -> Self { use common::consts::ARCH_PROVIDER_HINT_HEADER; use reqwest::header; // Create custom HTTP client with Arch provider hint header let mut headers = header::HeaderMap::new(); headers.insert( header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER), header::HeaderValue::from_str(&model_name).unwrap(), ); let http_client = reqwest::ClientBuilder::new() .default_headers(headers) .build() .expect("Failed to create HTTP client"); Self { model_name, config, default_prefix: r#"```json\n{\""#.to_string(), clarify_prefix: r#"```json\n{\"required_functions\":"#.to_string(), endpoint_url, http_client, } } /// Converts a list of tools into JSON format string pub fn convert_tools(&self, tools: &[Tool]) -> Result { let converted: std::result::Result, serde_json::Error> = tools .iter() .map(|tool| serde_json::to_string(&tool.function)) .collect(); converted .map(|v| v.join("\\n")) .map_err(FunctionCallingError::from) } /// Fixes malformed JSON strings by ensuring proper bracket matching pub fn fix_json_string(&self, json_str: &str) -> Result { let json_str = json_str.trim(); let mut stack: Vec = Vec::new(); let mut fixed_str = String::new(); let matching_bracket: HashMap = [(')', '('), ('}', '{'), (']', '[')] .iter() .cloned() .collect(); let opening_bracket: HashMap = matching_bracket.iter().map(|(k, v)| (*v, *k)).collect(); for ch in json_str.chars() { if ch == '{' || ch == '[' || ch == '(' { stack.push(ch); fixed_str.push(ch); } else if ch == '}' || ch == ']' || ch == ')' { if let Some(&last) = stack.last() { if matching_bracket.get(&ch) == Some(&last) { stack.pop(); fixed_str.push(ch); } // Ignore unmatched closing brackets } } else { fixed_str.push(ch); } } // Add corresponding closing brackets for unmatched opening brackets while let Some(unmatched_opening) = stack.pop() { if let Some(&closing) = opening_bracket.get(&unmatched_opening) { fixed_str.push(closing); } } // Try to parse the fixed JSON match serde_json::from_str::(&fixed_str) { Ok(val) => serde_json::to_string(&val).map_err(FunctionCallingError::from), Err(_) => { // Try replacing single quotes with double quotes let fixed_str = fixed_str.replace('\'', "\""); match serde_json::from_str::(&fixed_str) { Ok(val) => serde_json::to_string(&val).map_err(FunctionCallingError::from), Err(e) => Err(FunctionCallingError::JsonFixError(format!( "Failed to fix JSON: {}", e ))), } } } } /// Parses the model response and extracts tool call information pub fn parse_model_response(&self, content: &str) -> ParsedModelResponse { let mut response_dict = ParsedModelResponse::default(); // Remove markdown code blocks let mut content = content.trim().to_string(); if content.starts_with("```") && content.ends_with("```") { content = content .trim_start_matches("```") .trim_end_matches("```") .to_string(); if content.starts_with("json") { content = content.trim_start_matches("json").to_string(); } // Trim again after removing code blocks to eliminate internal whitespace content = content .trim_start_matches(r"\n") .trim_end_matches(r"\n") .to_string(); content = content.trim().to_string(); // Unescape the quotes: \" -> " // The model sometimes returns escaped JSON inside markdown blocks content = content.replace(r#"\""#, "\""); } // Try to fix JSON if needed let fixed_content = match self.fix_json_string(&content) { Ok(fixed) => { response_dict.raw_response = format!("```json\n{}\n```", fixed); fixed } Err(e) => { response_dict.is_valid = false; response_dict.error_message = format!("Failed to fix JSON: {}", e); return response_dict; } }; // Parse the JSON match serde_json::from_str::(&fixed_content) { Ok(model_response) => { // Successfully parsed - mark as valid response_dict.is_valid = true; // Extract response field if let Some(resp) = model_response.get("response") { if let Some(resp_str) = resp.as_str() { response_dict.response = Some(resp_str.to_string()); } } // Extract required_functions if let Some(funcs) = model_response.get("required_functions") { if let Some(funcs_arr) = funcs.as_array() { response_dict.required_functions = funcs_arr .iter() .filter_map(|v| v.as_str().map(String::from)) .collect(); } } // Extract clarification if let Some(clarif) = model_response.get("clarification") { if let Some(clarif_str) = clarif.as_str() { response_dict.clarification = clarif_str.to_string(); } } // Extract tool_calls if let Some(tool_calls) = model_response.get("tool_calls") { if let Some(tool_calls_arr) = tool_calls.as_array() { for tool_call_val in tool_calls_arr { let id = format!("call_{}", rand::random::() % 10000 + 1000); let name = tool_call_val .get("name") .and_then(|v| v.as_str()) .unwrap_or("") .to_string(); let arguments = tool_call_val .get("arguments") .map(|v| serde_json::to_string(v).unwrap_or_default()) .unwrap_or_default(); response_dict.tool_calls.push(ToolCall { id, call_type: "function".to_string(), function: FunctionCall { name, arguments }, }); } } } } Err(e) => { response_dict.is_valid = false; response_dict.error_message = format!("Failed to parse model response: {}", e); } } response_dict } /// Converts data type from one type to another pub fn convert_data_type(&self, value: &Value, target_type: &str) -> Result { match target_type { // Handle float/number conversions "float" | "number" => { if let Some(int_val) = value.as_i64() { return Ok(json!(int_val as f64)); } } // Handle list/array conversions "list" | "array" => { if let Some(str_val) = value.as_str() { // Try to parse as JSON array if let Ok(arr) = serde_json::from_str::>(str_val) { return Ok(json!(arr)); } } } // Handle str/string conversions "str" | "string" => { if !value.is_string() { return Ok(json!(value.to_string())); } } _ => {} } Ok(value.clone()) } /// Helper method to check if a value matches the expected type fn check_value_type(&self, value: &Value, target_type: &str) -> bool { match target_type { "int" | "integer" => value.is_i64() || value.is_u64(), "float" | "number" => value.is_f64() || value.is_i64() || value.is_u64(), "bool" | "boolean" => value.is_boolean(), "str" | "string" => value.is_string(), "list" | "array" => value.is_array(), "dict" | "object" => value.is_object(), _ => true, } } /// Helper method to validate and potentially convert a parameter value to match the target type /// Returns Ok(true) if the value is valid (either originally or after conversion) /// Returns Ok(false) if the value cannot be converted to the target type fn validate_or_convert_parameter( &self, param_value: &Value, target_type: &str, ) -> Result { // First check: Is it already the correct type? if self.check_value_type(param_value, target_type) { return Ok(true); } // Try to convert let converted = self.convert_data_type(param_value, target_type)?; // Second check: Is it the correct type after conversion? Ok(self.check_value_type(&converted, target_type)) } /// Verifies the validity of extracted tool calls against the provided tools pub fn verify_tool_calls( &self, tools: &[Tool], tool_calls: &[ToolCall], ) -> ToolCallVerification { let mut verification = ToolCallVerification::default(); // Build a map of function name to parameters let mut functions: HashMap = HashMap::new(); for tool in tools { functions.insert(tool.function.name.clone(), &tool.function.parameters); } for tool_call in tool_calls { if !verification.is_valid { break; } let func_name = &tool_call.function.name; // Parse arguments as JSON let func_args: HashMap = match serde_json::from_str(&tool_call.function.arguments) { Ok(args) => args, Err(e) => { verification.is_valid = false; verification.invalid_tool_call = Some(tool_call.clone()); verification.error_message = format!( "Failed to parse arguments for function '{}': {}", func_name, e ); break; } }; // Check if function is available if let Some(function_params) = functions.get(func_name) { // Check if all required parameters are present if let Some(required) = function_params.get("required") { if let Some(required_arr) = required.as_array() { for required_param in required_arr { if let Some(param_name) = required_param.as_str() { if !func_args.contains_key(param_name) { verification.is_valid = false; verification.invalid_tool_call = Some(tool_call.clone()); verification.error_message = format!( "`{}` is required by the function `{}` but not found in the tool call!", param_name, func_name ); break; } } } } } // Verify the data type of each parameter if let Some(properties) = function_params.get("properties") { if let Some(properties_obj) = properties.as_object() { for (param_name, param_value) in &func_args { if let Some(param_schema) = properties_obj.get(param_name) { if let Some(target_type) = param_schema.get("type").and_then(|v| v.as_str()) { if self .config .support_data_types .contains(&target_type.to_string()) { // Validate data type using helper method match self .validate_or_convert_parameter(param_value, target_type) { Ok(is_valid) => { if !is_valid { verification.is_valid = false; verification.invalid_tool_call = Some(tool_call.clone()); verification.error_message = format!( "Parameter `{}` is expected to have the data type `{}`, got incompatible type.", param_name, target_type ); break; } } Err(_) => { verification.is_valid = false; verification.invalid_tool_call = Some(tool_call.clone()); verification.error_message = format!( "Parameter `{}` is expected to have the data type `{}`, got incompatible type.", param_name, target_type ); break; } } } else { verification.is_valid = false; verification.invalid_tool_call = Some(tool_call.clone()); verification.error_message = format!( "Data type `{}` is not supported.", target_type ); break; } } } else { verification.is_valid = false; verification.invalid_tool_call = Some(tool_call.clone()); verification.error_message = format!( "Parameter `{}` is not defined in the function `{}`.", param_name, func_name ); break; } } } } } else { verification.is_valid = false; verification.invalid_tool_call = Some(tool_call.clone()); verification.error_message = format!("{} is not available!", func_name); } } verification } /// Formats the system prompt with tools pub fn format_system_prompt(&self, tools: &[Tool]) -> Result { let tools_str = self.convert_tools(tools)?; let system_prompt = self.config.task_prompt.replace("{tools}", &tools_str) + &self.config.format_prompt; Ok(system_prompt) } /// Processes messages and formats them appropriately for the model pub fn process_messages( &self, messages: &[Message], tools: Option<&[Tool]>, extra_instruction: Option<&str>, max_tokens: usize, metadata: Option<&HashMap>, ) -> Result> { let mut processed_messages = Vec::new(); // Add system message with tools if provided if let Some(tools) = tools { let system_prompt = self.format_system_prompt(tools)?; processed_messages.push(Message { role: Role::System, content: MessageContent::Text(system_prompt), name: None, tool_calls: None, tool_call_id: None, }); } // Process each message for (idx, message) in messages.iter().enumerate() { let mut role = message.role.clone(); let mut content = match &message.content { MessageContent::Text(text) => text.clone(), MessageContent::Parts(_) => String::new(), }; // Handle tool calls if let Some(tool_calls) = &message.tool_calls { if !tool_calls.is_empty() { role = Role::Assistant; let tool_call_json = serde_json::to_string(&tool_calls[0].function)?; content = format!("\n{}\n", tool_call_json); } } else if role == Role::Tool { role = Role::User; // Check if we should optimize context window let optimize_context = metadata .and_then(|m| m.get("optimize_context_window")) .and_then(|v| v.as_str()) .map(|s| s.to_lowercase() == "true") .unwrap_or(false); if optimize_context { content = "\n\n".to_string(); } else { // Get the tool call from previous message if idx > 0 { if let MessageContent::Text(prev_content) = &messages[idx - 1].content { let mut tool_call_msg = prev_content.clone(); // Strip markdown code blocks if tool_call_msg.starts_with("```") && tool_call_msg.ends_with("```") { tool_call_msg = tool_call_msg .trim_start_matches("```") .trim_end_matches("```") .trim() .to_string(); if tool_call_msg.starts_with("json") { tool_call_msg = tool_call_msg.trim_start_matches("json").trim().to_string(); } } // Extract function name if let Ok(parsed) = serde_json::from_str::(&tool_call_msg) { if let Some(tool_calls_arr) = parsed.get("tool_calls").and_then(|v| v.as_array()) { if let Some(first_tool_call) = tool_calls_arr.first() { let func_name = first_tool_call .get("name") .and_then(|v| v.as_str()) .unwrap_or("no_name"); let tool_response = json!({ "name": func_name, "result": content, }); content = format!( "\n{}\n", serde_json::to_string(&tool_response)? ); } } } } } } } processed_messages.push(Message { role, content: MessageContent::Text(content), name: message.name.clone(), tool_calls: None, tool_call_id: None, }); } // Ensure last message is from user if let Some(last) = processed_messages.last() { if last.role != Role::User { return Err(FunctionCallingError::InvalidModelResponse( "Last message must be from user".to_string(), )); } } // Add extra instruction if provided if let Some(instruction) = extra_instruction { if let Some(last) = processed_messages.last_mut() { if let MessageContent::Text(content) = &mut last.content { content.push('\n'); content.push_str(instruction); } } } // Truncate messages if they exceed max_tokens let processed_messages = self.truncate_messages(processed_messages, max_tokens); Ok(processed_messages) } /// Truncates messages to fit within max_tokens limit fn truncate_messages(&self, messages: Vec, max_tokens: usize) -> Vec { let mut num_tokens = 0; let mut conversation_idx = 0; // Keep system message if present if let Some(first) = messages.first() { if first.role == Role::System { if let MessageContent::Text(content) = &first.content { num_tokens += content.len() / 4; // Approximate 4 chars per token } conversation_idx = 1; } } // Calculate from the end backwards // Start with message_idx pointing past the end (will be used if no truncation needed) let mut message_idx = messages.len(); for i in (conversation_idx..messages.len()).rev() { if let MessageContent::Text(content) = &messages[i].content { num_tokens += content.len() / 4; if num_tokens >= max_tokens && messages[i].role == Role::User { // Set message_idx to current position and break // This matches Python's behavior where message_idx is set before break message_idx = i; break; } } // Only update message_idx if we haven't hit the token limit yet // This ensures message_idx points to where truncation should start if num_tokens < max_tokens { message_idx = i; } } // Return system message + truncated conversation let mut result = Vec::new(); if conversation_idx > 0 { result.push(messages[0].clone()); } result.extend_from_slice(&messages[message_idx..]); result } /// Prefills a message by adding an assistant message with the prefix pub fn prefill_message(&self, mut messages: Vec, prefill: &str) -> Vec { messages.push(Message { role: Role::Assistant, content: MessageContent::Text(prefill.to_string()), name: None, tool_calls: None, tool_call_id: None, }); messages } /// Helper to create a request with VLLM-specific parameters fn create_request_with_extra_body( &self, messages: Vec, stream: bool, ) -> ChatCompletionsRequest { ChatCompletionsRequest { model: self.model_name.clone(), messages, temperature: Some(self.config.generation_params.temperature), top_p: Some(self.config.generation_params.top_p), max_tokens: Some(self.config.generation_params.max_tokens), stream: Some(stream), logprobs: self.config.generation_params.logprobs, top_logprobs: self.config.generation_params.top_logprobs, // VLLM-specific parameters continue_final_message: Some(true), add_generation_prompt: Some(false), top_k: Some(self.config.generation_params.top_k), stop_token_ids: if !self.config.generation_params.stop_token_ids.is_empty() { Some(self.config.generation_params.stop_token_ids.clone()) } else { None }, ..Default::default() } } /// Makes a streaming request and returns the SSE event stream async fn make_streaming_request( &self, request: ChatCompletionsRequest, ) -> Result< std::pin::Pin> + Send>>, > { let request_body = serde_json::to_string(&request).map_err(|e| { FunctionCallingError::InvalidModelResponse(format!( "Failed to serialize request: {}", e )) })?; let response = self .http_client .post(&self.endpoint_url) .header("Content-Type", "application/json") .body(request_body) .send() .await .map_err(FunctionCallingError::HttpError)?; if !response.status().is_success() { let status = response.status(); let error_text = response .text() .await .unwrap_or_else(|_| "Unknown error".to_string()); return Err(FunctionCallingError::InvalidModelResponse(format!( "HTTP error {}: {}", status, error_text ))); } // Parse SSE stream let stream = response.bytes_stream().eventsource(); let parsed_stream = stream.filter_map(|event_result| async move { match event_result { Ok(event) => { // Skip [DONE] sentinel if event.data == "[DONE]" { return None; } // Parse JSON match serde_json::from_str::(&event.data) { Ok(json) => Some(Ok(json)), Err(e) => Some(Err(format!("JSON parse error: {}", e))), } } Err(e) => Some(Err(format!("SSE stream error: {}", e))), } }); Ok(Box::pin(parsed_stream)) } /// Makes a non-streaming request and returns the response async fn make_non_streaming_request( &self, request: ChatCompletionsRequest, ) -> Result { let request_body = serde_json::to_string(&request).map_err(|e| { FunctionCallingError::InvalidModelResponse(format!( "Failed to serialize request: {}", e )) })?; let response = self .http_client .post(&self.endpoint_url) .header("Content-Type", "application/json") .body(request_body) .send() .await .map_err(FunctionCallingError::HttpError)?; if !response.status().is_success() { let status = response.status(); let error_text = response .text() .await .unwrap_or_else(|_| "Unknown error".to_string()); return Err(FunctionCallingError::InvalidModelResponse(format!( "HTTP error {}: {}", status, error_text ))); } let response_text = response .text() .await .map_err(FunctionCallingError::HttpError)?; serde_json::from_str(&response_text).map_err(FunctionCallingError::JsonParseError) } pub async fn function_calling_chat( &self, request: ChatCompletionsRequest, ) -> Result { use tracing::{error, info}; info!("[Arch-Function] - ChatCompletion"); let messages = self.process_messages( &request.messages, request.tools.as_deref(), None, self.config.generation_params.max_tokens as usize, request.metadata.as_ref(), )?; info!( "[request to arch-fc]: model: {}, messages count: {}", self.model_name, messages.len() ); let use_agent_orchestrator = request .metadata .as_ref() .and_then(|m| m.get("use_agent_orchestrator")) .and_then(|v| v.as_bool()) .unwrap_or(false); let prefilled_messages = self.prefill_message(messages.clone(), &self.default_prefix); // Create request with extra_body parameters let stream_request = self.create_request_with_extra_body(prefilled_messages.clone(), true); let mut stream = self.make_streaming_request(stream_request).await?; let mut model_response = String::new(); if use_agent_orchestrator { while let Some(chunk_result) = stream.next().await { let chunk = chunk_result.map_err(FunctionCallingError::InvalidModelResponse)?; // Extract content from JSON response if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) { if let Some(choice) = choices.first() { if let Some(content) = choice .get("delta") .and_then(|d| d.get("content")) .and_then(|c| c.as_str()) { model_response.push_str(content); } } } } info!("[Agent Orchestrator]: response received"); } else if let Some(tools) = request.tools.as_ref() { let mut hallucination_state = HallucinationState::new(tools); let mut has_tool_calls = None; let mut has_hallucination = false; while let Some(chunk_result) = stream.next().await { let chunk = chunk_result.map_err(FunctionCallingError::InvalidModelResponse)?; // Extract content and logprobs from JSON response if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) { if let Some(choice) = choices.first() { if let Some(content) = choice .get("delta") .and_then(|d| d.get("content")) .and_then(|c| c.as_str()) { // Extract logprobs let logprobs: Vec = choice .get("logprobs") .and_then(|lp| lp.get("content")) .and_then(|c| c.as_array()) .and_then(|arr| arr.first()) .and_then(|token| token.get("top_logprobs")) .and_then(|tlp| tlp.as_array()) .map(|arr| { arr.iter() .filter_map(|v| v.get("logprob").and_then(|lp| lp.as_f64())) .collect() }) .unwrap_or_default(); if hallucination_state .append_and_check_token_hallucination(content.to_string(), logprobs) { has_hallucination = true; break; } if hallucination_state.tokens.len() > 5 && has_tool_calls.is_none() { let collected_content = hallucination_state.tokens.join(""); has_tool_calls = Some(collected_content.contains("tool_calls")); } } } } } if has_tool_calls == Some(true) && has_hallucination { info!("[Hallucination]: {}", hallucination_state.error_message); let clarify_messages = self.prefill_message(messages.clone(), &self.clarify_prefix); let clarify_request = self.create_request_with_extra_body(clarify_messages, false); let retry_response = self.make_non_streaming_request(clarify_request).await?; if let Some(choice) = retry_response.choices.first() { if let Some(content) = &choice.message.content { model_response = content.clone(); } } } else { model_response = hallucination_state.tokens.join(""); } } else { while let Some(chunk_result) = stream.next().await { let chunk = chunk_result.map_err(FunctionCallingError::InvalidModelResponse)?; if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) { if let Some(choice) = choices.first() { if let Some(content) = choice .get("delta") .and_then(|d| d.get("content")) .and_then(|c| c.as_str()) { model_response.push_str(content); } } } } } let response_dict = self.parse_model_response(&model_response); info!( "[arch-fc]: raw model response: {}", response_dict.raw_response ); // General model response (no intent matched - should route to default target) let model_message = if response_dict .response .as_ref() .is_some_and(|s| !s.is_empty()) { // When arch-fc returns a "response" field, it means no intent was matched // Return empty content and empty tool_calls so prompt_gateway routes to default target ResponseMessage { role: Role::Assistant, content: Some(String::new()), refusal: None, annotations: None, audio: None, function_call: None, tool_calls: None, } } else if !response_dict.required_functions.is_empty() { if !use_agent_orchestrator { ResponseMessage { role: Role::Assistant, content: Some(response_dict.clarification.clone()), refusal: None, annotations: None, audio: None, function_call: None, tool_calls: None, } } else { ResponseMessage { role: Role::Assistant, content: Some(String::new()), refusal: None, annotations: None, audio: None, function_call: None, tool_calls: None, } } } else if !response_dict.tool_calls.is_empty() { if response_dict.is_valid { if !use_agent_orchestrator { if let Some(tools) = request.tools.as_ref() { let verification = self.verify_tool_calls(tools, &response_dict.tool_calls); if verification.is_valid { info!( "[Tool calls]: {:?}", response_dict .tool_calls .iter() .map(|tc| &tc.function) .collect::>() ); ResponseMessage { role: Role::Assistant, content: Some(String::new()), refusal: None, annotations: None, audio: None, function_call: None, tool_calls: Some(response_dict.tool_calls.clone()), } } else { error!("Invalid tool call - {}", verification.error_message); ResponseMessage { role: Role::Assistant, content: Some(String::new()), refusal: None, annotations: None, audio: None, function_call: None, tool_calls: None, } } } else { error!("Tool calls present but no tools provided in request"); ResponseMessage { role: Role::Assistant, content: Some(String::new()), refusal: None, annotations: None, audio: None, function_call: None, tool_calls: None, } } } else { info!( "[Tool calls]: {:?}", response_dict .tool_calls .iter() .map(|tc| &tc.function) .collect::>() ); ResponseMessage { role: Role::Assistant, content: Some(String::new()), refusal: None, annotations: None, audio: None, function_call: None, tool_calls: Some(response_dict.tool_calls.clone()), } } } else { error!( "Invalid tool calls in response: {}", response_dict.error_message ); ResponseMessage { role: Role::Assistant, content: Some(String::new()), refusal: None, annotations: None, audio: None, function_call: None, tool_calls: None, } } } else { error!("Invalid model response - {}", model_response); ResponseMessage { role: Role::Assistant, content: Some(String::new()), refusal: None, annotations: None, audio: None, function_call: None, tool_calls: None, } }; // Create metadata with the raw model response let mut metadata = HashMap::new(); metadata.insert( "x-arch-fc-model-response".to_string(), serde_json::to_value(&response_dict.raw_response) .unwrap_or_else(|_| Value::String(response_dict.raw_response.clone())), ); let chat_completion_response = ChatCompletionsResponse { id: format!("chatcmpl-{}", uuid::Uuid::new_v4()), object: Some("chat.completion".to_string()), created: chrono::Utc::now().timestamp() as u64, model: request.model.clone(), choices: vec![Choice { index: 0, message: model_message, finish_reason: Some(FinishReason::Stop), logprobs: None, }], usage: Usage { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0, prompt_tokens_details: None, completion_tokens_details: None, }, system_fingerprint: None, service_tier: None, metadata: Some(metadata), }; info!("[response arch-fc]: {:?}", chat_completion_response); Ok(chat_completion_response) } } // ============================================================================ // ARCH AGENT HANDLER // ============================================================================ /// Handler for Arch Agent (extends ArchFunctionHandler with specialized behavior) pub struct ArchAgentHandler { pub function_handler: ArchFunctionHandler, } impl ArchAgentHandler { /// Creates a new ArchAgentHandler pub fn new(model_name: String, endpoint_url: String) -> Self { let config = ArchAgentConfig::default(); Self { function_handler: ArchFunctionHandler::new( model_name, ArchFunctionConfig { task_prompt: config.task_prompt, format_prompt: config.format_prompt, generation_params: GenerationParams { temperature: config.generation_params.temperature, top_p: config.generation_params.top_p, top_k: config.generation_params.top_k, max_tokens: config.generation_params.max_tokens, stop_token_ids: config.generation_params.stop_token_ids, logprobs: config.generation_params.logprobs, top_logprobs: config.generation_params.top_logprobs, }, support_data_types: config.support_data_types, }, endpoint_url, ), } } /// Converts tools with special handling for empty parameters /// This is the key difference from ArchFunctionHandler pub fn convert_tools(&self, tools: &[Tool]) -> Result { let mut converted = Vec::new(); for tool in tools { let mut tool_copy = tool.clone(); // Delete parameters key if its empty if let Some(props) = tool_copy.function.parameters.get("properties") { if props.is_object() && props.as_object().unwrap().is_empty() { // Create new parameters without properties if let Some(params_obj) = tool_copy.function.parameters.as_object_mut() { params_obj.remove("properties"); } } } converted.push(serde_json::to_string(&tool_copy.function)?); } Ok(converted.join("\n")) } } // ============================================================================ // HTTP HANDLER FOR FUNCTION CALLING ENDPOINT // ============================================================================ fn full>(chunk: T) -> BoxBody { Full::new(chunk.into()) .map_err(|never| match never {}) .boxed() } pub async fn function_calling_chat_handler( req: Request, llm_provider_url: String, ) -> std::result::Result>, hyper::Error> { use hermesllm::apis::openai::ChatCompletionsRequest; let whole_body = req.collect().await?.to_bytes(); // Parse as JSON Value first to modify it let mut body_json: Value = match serde_json::from_slice(&whole_body) { Ok(json) => json, Err(e) => { error!("Failed to parse request body as JSON: {}", e); let mut response = Response::new(full( serde_json::json!({ "error": format!("Invalid request body: {}", e) }) .to_string(), )); *response.status_mut() = StatusCode::BAD_REQUEST; response .headers_mut() .insert("Content-Type", "application/json".parse().unwrap()); return Ok(response); } }; // Add "model": "Arch-Function" to the request if let Some(obj) = body_json.as_object_mut() { obj.insert("model".to_string(), ARCH_FUNCTION_MODEL_NAME.into()); } // Parse as ChatCompletionsRequest let chat_request: ChatCompletionsRequest = match serde_json::from_value(body_json) { Ok(req) => { info!( "[request body]: {}", serde_json::to_string(&req).unwrap_or_default() ); req } Err(e) => { error!("Failed to parse request body: {}", e); let mut response = Response::new(full( serde_json::json!({ "error": format!("Invalid request body: {}", e) }) .to_string(), )); *response.status_mut() = StatusCode::BAD_REQUEST; response .headers_mut() .insert("Content-Type", "application/json".parse().unwrap()); return Ok(response); } }; // Determine which handler to use based on metadata let use_agent_orchestrator = chat_request .metadata .as_ref() .and_then(|m| m.get("use_agent_orchestrator")) .and_then(|v| v.as_bool()) .unwrap_or(false); info!("Use agent orchestrator: {}", use_agent_orchestrator); // Create the appropriate handler let handler_name = if use_agent_orchestrator { "Arch-Agent" } else { "Arch-Function" }; // Call the handler let final_response = if use_agent_orchestrator { let handler = ArchAgentHandler::new( ARCH_FUNCTION_MODEL_NAME.to_string(), llm_provider_url.clone(), ); handler .function_handler .function_calling_chat(chat_request) .await } else { let handler = ArchFunctionHandler::new( ARCH_FUNCTION_MODEL_NAME.to_string(), ArchFunctionConfig::default(), llm_provider_url.clone(), ); handler.function_calling_chat(chat_request).await }; match final_response { Ok(response_data) => { let response_json = serde_json::to_string(&response_data).unwrap_or_else(|e| { error!("Failed to serialize response: {}", e); serde_json::json!({"error": "Failed to serialize response"}).to_string() }); let mut response = Response::new(full(response_json)); *response.status_mut() = StatusCode::OK; response .headers_mut() .insert("Content-Type", "application/json".parse().unwrap()); Ok(response) } Err(e) => { error!("[{}] - Error in function calling: {}", handler_name, e); let error_response = serde_json::json!({ "error": format!("[{}] - Error in function calling: {}", handler_name, e) }); let mut response = Response::new(full(error_response.to_string())); *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; response .headers_mut() .insert("Content-Type", "application/json".parse().unwrap()); Ok(response) } } } // ============================================================================ // TESTS // ============================================================================ #[cfg(test)] mod tests { use super::*; #[test] fn test_arch_function_config_default() { let config = ArchFunctionConfig::default(); assert!(config.task_prompt.contains("helpful assistant")); assert!(config.format_prompt.contains("JSON formats")); assert_eq!(config.generation_params.temperature, 0.1); assert_eq!(config.support_data_types.len(), 14); // 8 Python-style + 6 JSON Schema names // Verify prompt formatting for literal escaped newlines ("\\n") instead of actual newline chars // The user requirement changed prompts to display "\\n" sequences literally. assert!(config.task_prompt.contains("\\n\\nYou are provided")); assert!(config.task_prompt.contains("\\n\\n")); // Format prompt should contain literal escaped newlines and proper JSON examples assert!(config .format_prompt .contains("\\n\\nBased on your analysis")); assert!(config .format_prompt .contains(r#"{\"response\": \"Your response text here\"}"#)); assert!(config.format_prompt.contains(r#"{\"tool_calls\": [{"#)); } #[test] fn test_arch_agent_config_default() { let config = ArchAgentConfig::default(); assert_eq!(config.generation_params.temperature, 0.01); // Different from ArchFunctionConfig } #[test] fn test_fix_json_string_valid() { let handler = ArchFunctionHandler::new( "test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string(), ); let json_str = r#"{"name": "test", "value": 123}"#; let result = handler.fix_json_string(json_str); assert!(result.is_ok()); } #[test] fn test_fix_json_string_missing_bracket() { let handler = ArchFunctionHandler::new( "test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string(), ); let json_str = r#"{"name": "test", "value": 123"#; let result = handler.fix_json_string(json_str); assert!(result.is_ok()); let fixed = result.unwrap(); assert!(fixed.contains("}")); } #[test] fn test_parse_model_response_with_tool_calls() { let handler = ArchFunctionHandler::new( "test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string(), ); let content = r#"{"tool_calls": [{"name": "get_weather", "arguments": {"location": "NYC"}}]}"#; let result = handler.parse_model_response(content); assert!(result.is_valid); assert_eq!(result.tool_calls.len(), 1); assert_eq!(result.tool_calls[0].function.name, "get_weather"); } #[test] fn test_parse_model_response_with_clarification() { let handler = ArchFunctionHandler::new( "test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string(), ); let content = r#"{"required_functions": ["get_weather"], "clarification": "What location?"}"#; let result = handler.parse_model_response(content); assert!(result.is_valid); assert_eq!(result.required_functions.len(), 1); assert_eq!(result.clarification, "What location?"); } #[test] fn test_convert_data_type_int_to_float() { let handler = ArchFunctionHandler::new( "test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string(), ); let value = json!(42); let result = handler.convert_data_type(&value, "float"); assert!(result.is_ok()); assert!(result.unwrap().is_f64()); } } // ============================================================================ // HALLUCINATION DETECTION MODULE // ============================================================================ /// Mask token types for tracking parsing state #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum MaskToken { FunctionName, ParameterValue, ParameterName, NotUsed, ToolCall, } /// Uncertainty metrics calculated from log probabilities #[derive(Debug, Clone)] pub struct UncertaintyMetrics { pub entropy: f64, pub varentropy: f64, pub probability: f64, } /// Calculates uncertainty metrics from log probabilities /// /// This is a simplified Rust implementation that avoids torch/tensor dependencies. /// Uses basic statistical calculations instead of tensor operations. pub fn calculate_uncertainty(log_probs: &[f64]) -> UncertaintyMetrics { if log_probs.is_empty() { return UncertaintyMetrics { entropy: 0.0, varentropy: 0.0, probability: 0.0, }; } // Convert log probabilities to probabilities let token_probs: Vec = log_probs.iter().map(|&lp| lp.exp()).collect(); // Calculate entropy: -sum(p * log(p)) / log(2) let mut entropy = 0.0; for i in 0..log_probs.len() { entropy -= log_probs[i] * token_probs[i]; } entropy /= 2_f64.ln(); // Convert to bits // Calculate variance of entropy let mut varentropy = 0.0; for i in 0..log_probs.len() { let diff = log_probs[i] / 2_f64.ln() + entropy; varentropy += token_probs[i] * diff * diff; } // Get the top probability let probability = token_probs.first().copied().unwrap_or(0.0); UncertaintyMetrics { entropy, varentropy, probability, } } /// Checks if uncertainty metrics exceed thresholds pub fn check_threshold( entropy: f64, varentropy: f64, thresholds: &HallucinationThresholds, ) -> bool { entropy > thresholds.entropy && varentropy > thresholds.varentropy } /// Checks if a parameter is required in the function description pub fn is_parameter_required(function_description: &Value, parameter_name: &str) -> bool { if let Some(required) = function_description.get("required") { if let Some(required_arr) = required.as_array() { return required_arr .iter() .any(|v| v.as_str() == Some(parameter_name)); } } false } /// Checks if a parameter has a specific property pub fn is_parameter_property( function_description: &Value, parameter_name: &str, property_name: &str, ) -> bool { if let Some(properties) = function_description.get("properties") { if let Some(param_info) = properties.get(parameter_name) { return param_info.get(property_name).is_some(); } } false } /// State for hallucination detection during streaming /// /// This is a simplified version of the Python HallucinationState that doesn't /// require torch/tensor dependencies. It provides the core functionality needed /// for detecting hallucinations during function calling. #[derive(Debug)] pub struct HallucinationState { pub tokens: Vec, pub logprobs: Vec>, pub state: Option, pub mask: Vec, pub parameter_name_done: bool, pub hallucination: bool, pub error_message: String, pub parameter_name: Vec, pub token_probs_map: Vec<(String, f64, f64, f64)>, pub function_properties: HashMap, pub open_bracket: bool, pub bracket: Option, pub function_name: String, pub check_parameter_name: HashMap, pub thresholds: HallucinationThresholds, } impl HallucinationState { /// Creates a new HallucinationState with function definitions pub fn new(functions: &[Tool]) -> Self { let function_properties: HashMap = functions .iter() .map(|tool| (tool.function.name.clone(), tool.function.parameters.clone())) .collect(); Self { tokens: Vec::new(), logprobs: Vec::new(), state: None, mask: Vec::new(), parameter_name_done: false, hallucination: false, error_message: String::new(), parameter_name: Vec::new(), token_probs_map: Vec::new(), function_properties, open_bracket: false, bracket: None, function_name: String::new(), check_parameter_name: HashMap::new(), thresholds: HallucinationThresholds::default(), } } /// Appends a token and checks for hallucination pub fn append_and_check_token_hallucination( &mut self, token: String, logprob: Vec, ) -> bool { self.tokens.push(token); self.logprobs.push(logprob); self.process_token(); self.hallucination } /// Resets internal parameters fn reset_parameters(&mut self) { self.state = None; self.parameter_name_done = false; self.hallucination = false; self.error_message.clear(); self.open_bracket = false; self.bracket = None; self.check_parameter_name.clear(); } /// Processes the current token and updates state fn process_token(&mut self) { let content: String = self.tokens.join("").replace(' ', ""); // Handle end of tool call if content.ends_with(END_TOOL_CALL_TOKEN) { self.reset_parameters(); } // Function name extraction logic if self.state.as_deref() == Some("function_name") { if !FUNC_NAME_END_TOKEN .iter() .any(|&t| self.tokens.last().is_some_and(|tok| tok == t)) { self.mask.push(MaskToken::FunctionName); } else { self.state = None; self.get_function_name(); } } // Check for function name start if FUNC_NAME_START_PATTERN .iter() .any(|&p| content.ends_with(p)) { self.state = Some("function_name".to_string()); } // Parameter name extraction logic if self.state.as_deref() == Some("parameter_name") && !PARAMETER_NAME_END_TOKENS .iter() .any(|&t| content.ends_with(t)) { self.mask.push(MaskToken::ParameterName); } else if self.state.as_deref() == Some("parameter_name") && PARAMETER_NAME_END_TOKENS .iter() .any(|&t| content.ends_with(t)) { self.state = None; self.parameter_name_done = true; self.get_parameter_name(); } else if self.parameter_name_done && !self.open_bracket && PARAMETER_NAME_START_PATTERN .iter() .any(|&p| content.ends_with(p)) { self.state = Some("parameter_name".to_string()); } // First parameter value start if FIRST_PARAM_NAME_START_PATTERN .iter() .any(|&p| content.ends_with(p)) { self.state = Some("parameter_name".to_string()); } // Parameter value extraction logic if self.state.as_deref() == Some("parameter_value") && !PARAMETER_VALUE_END_TOKEN .iter() .any(|&t| content.ends_with(t)) { // Check for brackets if let Some(last_token) = self.tokens.last() { let open_brackets: Vec = last_token .trim() .chars() .filter(|&c| c == '(' || c == '{' || c == '[') .collect(); if !open_brackets.is_empty() { self.open_bracket = true; self.bracket = Some(open_brackets[0]); } if self.open_bracket { let closing = match self.bracket { Some('(') => ')', Some('{') => '}', Some('[') => ']', _ => '\0', }; if last_token.trim().contains(closing) { self.open_bracket = false; self.bracket = None; } } // Check if token has actual value content let has_non_punct = last_token.trim().chars().any(|c| !c.is_ascii_punctuation()); if has_non_punct && !last_token.trim().is_empty() { self.mask.push(MaskToken::ParameterValue); // Check hallucination for required parameters without enum if self.function_properties.contains_key(&self.function_name) { if self.mask.len() > 1 && self.mask[self.mask.len() - 2] != MaskToken::ParameterValue && !self.parameter_name.is_empty() { let last_param = self.parameter_name[self.parameter_name.len() - 1].clone(); if let Some(func_props) = self.function_properties.get(&self.function_name) { if is_parameter_required(func_props, &last_param) && !is_parameter_property(func_props, &last_param, "enum") && !self.check_parameter_name.contains_key(&last_param) { self.check_logprob(); self.check_parameter_name.insert(last_param, true); } } } } else if !self.function_name.is_empty() { self.check_logprob(); self.error_message = format!( "Function name {} not found in function properties", self.function_name ); } } else { self.mask.push(MaskToken::NotUsed); } } } else if self.state.as_deref() == Some("parameter_value") && !self.open_bracket && PARAMETER_VALUE_END_TOKEN .iter() .any(|&t| content.ends_with(t)) { self.state = None; } else if self.parameter_name_done && PARAMETER_VALUE_START_PATTERN .iter() .any(|&p| content.ends_with(p)) { self.state = Some("parameter_value".to_string()); } // Maintain consistency between tokens and mask if self.mask.len() != self.tokens.len() { self.mask.push(MaskToken::NotUsed); } } /// Checks log probability and detects hallucination fn check_logprob(&mut self) { if let Some(probs) = self.logprobs.last() { let metrics = calculate_uncertainty(probs); if let Some(token) = self.tokens.last() { self.token_probs_map.push(( token.clone(), metrics.entropy, metrics.varentropy, metrics.probability, )); if check_threshold(metrics.entropy, metrics.varentropy, &self.thresholds) { self.hallucination = true; self.error_message = format!( "token '{}' is uncertain. Generated response:\n{}", token, self.tokens.join("") ); } } } } /// Counts consecutive tokens of a specific type in the mask fn count_consecutive_token(&self, token_type: MaskToken) -> usize { if self.mask.is_empty() || self.mask.last() != Some(&token_type) { return 0; } self.mask .iter() .rev() .take_while(|&&t| t == token_type) .count() } /// Extracts the parameter name from recent tokens fn get_parameter_name(&mut self) { let p_len = self.count_consecutive_token(MaskToken::ParameterName); if p_len > 0 && self.tokens.len() > 1 { let start_idx = self.tokens.len().saturating_sub(p_len + 1); let end_idx = self.tokens.len().saturating_sub(1); let parameter_name: String = self.tokens[start_idx..end_idx].join(""); self.parameter_name.push(parameter_name); } } /// Extracts the function name from recent tokens fn get_function_name(&mut self) { let f_len = self.count_consecutive_token(MaskToken::FunctionName); if f_len > 0 && self.tokens.len() > 1 { let start_idx = self.tokens.len().saturating_sub(f_len + 1); let end_idx = self.tokens.len().saturating_sub(1); self.function_name = self.tokens[start_idx..end_idx].join(""); } } } #[cfg(test)] mod hallucination_tests { use super::*; #[test] fn test_calculate_uncertainty() { let log_probs = vec![-0.1, -2.0, -3.0]; let metrics = calculate_uncertainty(&log_probs); assert!(metrics.entropy >= 0.0); assert!(metrics.varentropy >= 0.0); assert!(metrics.probability > 0.0 && metrics.probability <= 1.0); } #[test] fn test_calculate_uncertainty_empty() { let log_probs: Vec = vec![]; let metrics = calculate_uncertainty(&log_probs); assert_eq!(metrics.entropy, 0.0); assert_eq!(metrics.varentropy, 0.0); assert_eq!(metrics.probability, 0.0); } #[test] fn test_check_threshold() { let thresholds = HallucinationThresholds::default(); assert!(check_threshold(0.001, 0.001, &thresholds)); assert!(!check_threshold(0.00001, 0.00001, &thresholds)); } #[test] fn test_is_parameter_required() { let func_desc = json!({ "required": ["param1", "param2"] }); assert!(is_parameter_required(&func_desc, "param1")); assert!(!is_parameter_required(&func_desc, "param3")); } #[test] fn test_is_parameter_property() { let func_desc = json!({ "properties": { "param1": { "type": "string", "enum": ["a", "b"] } } }); assert!(is_parameter_property(&func_desc, "param1", "enum")); assert!(!is_parameter_property(&func_desc, "param1", "default")); } #[test] fn test_check_value_type() { let handler = ArchFunctionHandler::new( "test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string(), ); // Test integer types assert!(handler.check_value_type(&json!(42), "integer")); assert!(handler.check_value_type(&json!(42), "int")); assert!(!handler.check_value_type(&json!(3.15), "integer")); // Test number types (accepts both int and float) assert!(handler.check_value_type(&json!(3.15), "number")); assert!(handler.check_value_type(&json!(42), "number")); assert!(handler.check_value_type(&json!(3.15), "float")); // Test boolean assert!(handler.check_value_type(&json!(true), "boolean")); assert!(handler.check_value_type(&json!(false), "bool")); assert!(!handler.check_value_type(&json!("true"), "boolean")); // Test string assert!(handler.check_value_type(&json!("hello"), "string")); assert!(handler.check_value_type(&json!("hello"), "str")); assert!(!handler.check_value_type(&json!(123), "string")); // Test array assert!(handler.check_value_type(&json!([1, 2, 3]), "array")); assert!(handler.check_value_type(&json!([1, 2, 3]), "list")); assert!(!handler.check_value_type(&json!({}), "array")); // Test object assert!(handler.check_value_type(&json!({"key": "value"}), "object")); assert!(handler.check_value_type(&json!({"key": "value"}), "dict")); assert!(!handler.check_value_type(&json!([]), "object")); // Test unknown type (should return true) assert!(handler.check_value_type(&json!(42), "unknown_type")); } #[test] fn test_validate_or_convert_parameter() { let handler = ArchFunctionHandler::new( "test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string(), ); // Test valid type - no conversion needed assert!(handler .validate_or_convert_parameter(&json!(42), "integer") .unwrap()); assert!(handler .validate_or_convert_parameter(&json!("hello"), "string") .unwrap()); // Test integer to float conversion (convert_data_type supports this) let result = handler.validate_or_convert_parameter(&json!(42), "float"); assert!(result.is_ok()); assert!(result.unwrap()); // Should be valid after conversion // Test invalid type that cannot be converted // A string cannot be converted to integer (convert_data_type doesn't support this) let result = handler.validate_or_convert_parameter(&json!("abc"), "integer"); // Since convert_data_type returns Ok(value.clone()) for unsupported conversions, // the validation will fail because "abc" string is not an integer assert!(!result.unwrap()); // Test number accepting both int and float assert!(handler .validate_or_convert_parameter(&json!(42), "number") .unwrap()); assert!(handler .validate_or_convert_parameter(&json!(3.15), "number") .unwrap()); } #[test] fn test_hallucination_state_new() { let tools = vec![Tool { tool_type: "function".to_string(), function: hermesllm::apis::openai::Function { name: "test_func".to_string(), description: Some("Test function".to_string()), parameters: json!({"type": "object"}), strict: None, }, }]; let state = HallucinationState::new(&tools); assert_eq!(state.tokens.len(), 0); assert!(!state.hallucination); assert!(state.function_properties.contains_key("test_func")); } }