2025-12-25 21:08:37 -08:00
use bytes ::Bytes ;
use eventsource_stream ::Eventsource ;
use futures ::StreamExt ;
2025-11-22 12:55:00 -08:00
use hermesllm ::apis ::openai ::{
ChatCompletionsRequest , ChatCompletionsResponse , Choice , FinishReason , FunctionCall , Message ,
MessageContent , ResponseMessage , Role , Tool , ToolCall , Usage ,
} ;
2025-12-25 21:08:37 -08:00
use http_body_util ::{ combinators ::BoxBody , BodyExt , Full } ;
use hyper ::body ::Incoming ;
use hyper ::{ Request , Response , StatusCode } ;
2025-11-22 12:55:00 -08:00
use serde ::{ Deserialize , Serialize } ;
use serde_json ::{ json , Value } ;
use std ::collections ::HashMap ;
use thiserror ::Error ;
2025-12-25 21:08:37 -08:00
use tracing ::{ error , info } ;
2025-11-22 12:55:00 -08:00
// ============================================================================
// CONSTANTS FOR HALLUCINATION DETECTION
// ============================================================================
// NOTE(review): nothing in the visible part of this file reads these markers, so the
// one-line descriptions below are inferred from the identifiers only; confirm them
// against the hallucination-detection code that consumes the constants.

/// Prefixes that presumably open a function name inside a serialized tool call.
const FUNC_NAME_START_PATTERN: &[&str] = &[r#"{"name":""#, r#"{'name':'"#];

/// Tokens that presumably terminate a function name.
const FUNC_NAME_END_TOKEN: &[&str] = &[" \" , ", " ', "];

/// Token that presumably terminates a complete tool call.
const END_TOOL_CALL_TOKEN: &str = " }} ";

/// Prefixes that presumably introduce the first parameter name of a tool call.
const FIRST_PARAM_NAME_START_PATTERN: &[&str] = &[r#""arguments":{"#, r#"'arguments':{'"#];

/// Tokens that presumably terminate a parameter name.
const PARAMETER_NAME_END_TOKENS: &[&str] = &[
    " \" : ",
    " : \" ",
    " ': ",
    " :' ",
    " \" : \" ",
    " ':' ",
];

/// Prefixes that presumably introduce every parameter name after the first.
const PARAMETER_NAME_START_PATTERN: &[&str] = &[" \" , \" ", " ',' "];

/// Prefixes that presumably introduce a parameter value.
const PARAMETER_VALUE_START_PATTERN: &[&str] = &[" \" : ", " ': "];

/// Tokens that presumably terminate a parameter value.
const PARAMETER_VALUE_END_TOKEN: &[&str] = &[" \" , ", " \" } "];

/// Model name constant for the Arch function-calling model.
const ARCH_FUNCTION_MODEL_NAME: &str = " Arch-Function ";
/// Numeric cut-offs for hallucination detection.
///
/// NOTE(review): the code that compares token statistics against these values is not
/// part of the visible source, so the direction of each comparison is not documented here.
#[derive(Debug, Clone)]
pub struct HallucinationThresholds {
    /// Entropy cut-off; `0.0001` by default.
    pub entropy: f64,
    /// Varentropy cut-off; `0.0001` by default.
    pub varentropy: f64,
    /// Probability cut-off; `0.8` by default.
    pub probability: f64,
}

impl Default for HallucinationThresholds {
    /// Returns the stock thresholds (`0.0001`, `0.0001`, `0.8`).
    fn default() -> Self {
        HallucinationThresholds {
            probability: 0.8,
            varentropy: 0.0001,
            entropy: 0.0001,
        }
    }
}
// ============================================================================
// ERROR TYPES
// ============================================================================
#[ derive(Debug, Error) ]
pub enum FunctionCallingError {
#[ error( " Failed to parse JSON: {0} " ) ]
JsonParseError ( #[ from ] serde_json ::Error ) ,
#[ error( " Failed to fix malformed JSON: {0} " ) ]
JsonFixError ( String ) ,
#[ error( " Invalid model response: {0} " ) ]
InvalidModelResponse ( String ) ,
#[ error( " Tool call verification failed: {0} " ) ]
ToolCallVerificationError ( String ) ,
#[ error( " Data type conversion error: {0} " ) ]
DataTypeConversionError ( String ) ,
#[ error( " Unsupported data type: {0} " ) ]
UnsupportedDataType ( String ) ,
#[ error( " HTTP request error: {0} " ) ]
HttpError ( #[ from ] reqwest ::Error ) ,
#[ error( " Invalid tool call: {0} " ) ]
InvalidToolCall ( String ) ,
}
pub type Result < T > = std ::result ::Result < T , FunctionCallingError > ;
// ============================================================================
// CONFIGURATION STRUCTURES
// ============================================================================
/// Prompt templates, sampling parameters and accepted parameter types for Arch
/// function calling.
#[derive(Debug, Clone)]
pub struct ArchFunctionConfig {
    /// System-prompt template; its `{tools}` placeholder is replaced with the
    /// serialized tool list when the system prompt is formatted.
    pub task_prompt: String,
    /// Text appended to the task prompt that describes the expected JSON answer formats.
    pub format_prompt: String,
    /// Sampling parameters forwarded with every model request.
    pub generation_params: GenerationParams,
    /// Parameter type names accepted during tool-call verification.
    pub support_data_types: Vec<String>,
}

impl Default for ArchFunctionConfig {
    /// Builds the stock configuration: built-in prompts, default generation
    /// parameters and the full list of supported type names.
    fn default() -> Self {
        // Python-style type names first, then the standard JSON Schema names.
        let support_data_types: Vec<String> = [
            " int ",
            " float ",
            " bool ",
            " str ",
            " list ",
            " tuple ",
            " set ",
            " dict ",
            " integer ",
            " number ",
            " boolean ",
            " string ",
            " array ",
            " object ",
        ]
        .iter()
        .map(|type_name| type_name.to_string())
        .collect();

        // Both prompts are raw strings on purpose: the `\n` sequences stay literal
        // (backslash + `n`) in the final prompt instead of becoming real newlines.
        let task_prompt = String::from(
            r#" You are a helpful assistant designed to assist with the user query by making one or more function calls if needed. \n \n You are provided with function signatures within <tools></tools> XML tags: \n <tools> \n {tools} \n </tools> \n \n Your task is to decide which functions are needed and collect missing parameters if necessary. "#,
        );
        let format_prompt = String::from(
            r#" \n \n Based on your analysis, provide your response in one of the following JSON formats: \n 1. If no functions are needed: \n ```json \n { \" response \" : \" Your response text here \" } \n ``` \n 2. If functions are needed but some required parameters are missing: \n ```json \n { \" required_functions \" : [ \" func_name1 \" , \" func_name2 \" , ...], \" clarification \" : \" Text asking for missing parameters \" } \n ``` \n 3. If functions are needed and all required parameters are available: \n ```json \n { \" tool_calls \" : [{ \" name \" : \" func_name1 \" , \" arguments \" : { \" argument1 \" : \" value1 \" , \" argument2 \" : \" value2 \" }},... (more tool calls as required)]} \n ``` "#,
        );

        ArchFunctionConfig {
            task_prompt,
            format_prompt,
            generation_params: GenerationParams::default(),
            support_data_types,
        }
    }
}
/// Configuration for Arch Agent (extends ArchFunctionConfig with different generation params)
#[ derive(Debug, Clone) ]
pub struct ArchAgentConfig {
pub task_prompt : String ,
pub format_prompt : String ,
pub generation_params : GenerationParams ,
pub support_data_types : Vec < String > ,
}
impl Default for ArchAgentConfig {
fn default ( ) -> Self {
let base = ArchFunctionConfig ::default ( ) ;
Self {
task_prompt : base . task_prompt ,
format_prompt : base . format_prompt ,
generation_params : GenerationParams {
temperature : 0.01 ,
top_p : 1.0 ,
top_k : 10 ,
max_tokens : 1024 ,
stop_token_ids : vec ! [ 151645 ] ,
logprobs : Some ( true ) ,
top_logprobs : Some ( 10 ) ,
} ,
support_data_types : base . support_data_types ,
}
}
}
/// Sampling and decoding parameters sent to the model with each request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerationParams {
    /// Sampling temperature.
    pub temperature: f32,
    /// Nucleus-sampling (`top_p`) value.
    pub top_p: f32,
    /// `top_k` sampling value.
    pub top_k: u32,
    /// Maximum number of tokens requested from the model.
    pub max_tokens: u32,
    /// Token ids that stop generation; an empty list means none are sent.
    pub stop_token_ids: Vec<u32>,
    /// Whether log-probabilities are requested.
    pub logprobs: Option<bool>,
    /// How many top log-probabilities are requested per token.
    pub top_logprobs: Option<u32>,
}

impl Default for GenerationParams {
    /// Returns the stock parameters: temperature `0.1`, `top_p` `1.0`, `top_k` `10`,
    /// `1024` max tokens, stop token `151645`, log-probabilities on with the top `10`.
    fn default() -> Self {
        let stop_token_ids = vec![151645];
        GenerationParams {
            temperature: 0.1,
            top_p: 1.0,
            top_k: 10,
            max_tokens: 1024,
            stop_token_ids,
            logprobs: Some(true),
            top_logprobs: Some(10),
        }
    }
}
// ============================================================================
// PARSED MODEL RESPONSE
// ============================================================================
/// Structured view of one model answer, as produced by `parse_model_response`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ParsedModelResponse {
    /// The repaired JSON answer wrapped in a json code fence.
    pub raw_response: String,
    /// Plain-text answer, when the model supplied a string `response` field.
    pub response: Option<String>,
    /// Function names listed by the model in `required_functions`.
    pub required_functions: Vec<String>,
    /// Clarification text from the model; empty when none was given.
    pub clarification: String,
    /// Tool calls extracted from the answer.
    pub tool_calls: Vec<ToolCall>,
    /// `true` once the answer was repaired and parsed successfully.
    pub is_valid: bool,
    /// Reason for a repair or parse failure; empty on success.
    pub error_message: String,
}
// ============================================================================
// TOOL CALL VERIFICATION RESULT
// ============================================================================
/// Outcome of checking tool calls against the declared tools.
#[derive(Debug, Clone)]
pub struct ToolCallVerification {
    /// `true` when no problem was found.
    pub is_valid: bool,
    /// The tool call that failed verification, if any.
    pub invalid_tool_call: Option<ToolCall>,
    /// Human-readable reason for the failure; empty when valid.
    pub error_message: String,
}

impl Default for ToolCallVerification {
    /// Starts from the "everything is fine" state: valid, no offending call, empty message.
    fn default() -> Self {
        ToolCallVerification {
            error_message: String::new(),
            invalid_tool_call: None,
            is_valid: true,
        }
    }
}
/// Entry point for Arch function calling: owns the configuration and the HTTP
/// client used to talk to the model endpoint.
pub struct ArchFunctionHandler {
    /// Model name placed in requests and in the provider-hint header.
    pub model_name: String,
    /// Prompts, generation parameters and supported data types.
    pub config: ArchFunctionConfig,
    /// Assistant prefill that opens a fenced JSON object.
    pub default_prefix: String,
    /// Assistant prefill that opens a fenced JSON object at `required_functions`.
    pub clarify_prefix: String,
    /// URL that model requests are POSTed to.
    pub endpoint_url: String,
    /// Client preconfigured with the Arch provider-hint header.
    pub http_client: reqwest::Client,
}
impl ArchFunctionHandler {
/// Builds a handler whose HTTP client sends the Arch provider-hint header, set to
/// `model_name`, with every request.
///
/// # Panics
///
/// Panics if `model_name` is not a valid HTTP header value, or if the HTTP client
/// cannot be constructed.
pub fn new(model_name: String, config: ArchFunctionConfig, endpoint_url: String) -> Self {
    use common::consts::ARCH_PROVIDER_HINT_HEADER;
    use reqwest::header;

    // The provider hint travels as a default header so no call site can forget it.
    let hint_name = header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER);
    let hint_value = header::HeaderValue::from_str(&model_name).unwrap();
    let mut default_headers = header::HeaderMap::new();
    default_headers.insert(hint_name, hint_value);

    let http_client = reqwest::ClientBuilder::new()
        .default_headers(default_headers)
        .build()
        .expect(" Failed to create HTTP client ");

    // Raw strings keep the backslash sequences literal in both prefixes.
    let default_prefix = String::from(r#" ```json \n { \" "#);
    let clarify_prefix = String::from(r#" ```json \n { \" required_functions \" : "#);

    Self {
        model_name,
        config,
        default_prefix,
        clarify_prefix,
        endpoint_url,
        http_client,
    }
}
/// Serializes the `function` definition of every tool to JSON and joins the results
/// with the separator literal.
///
/// # Errors
///
/// Returns `FunctionCallingError::JsonParseError` if any definition fails to serialize.
pub fn convert_tools(&self, tools: &[Tool]) -> Result<String> {
    let mut serialized: Vec<String> = Vec::with_capacity(tools.len());
    for tool in tools {
        // `?` converts the serde error through the `From` impl on the error enum.
        serialized.push(serde_json::to_string(&tool.function)?);
    }
    Ok(serialized.join(" \\ n "))
}
/// Fixes malformed JSON strings by ensuring proper bracket matching
pub fn fix_json_string ( & self , json_str : & str ) -> Result < String > {
let json_str = json_str . trim ( ) ;
let mut stack : Vec < char > = Vec ::new ( ) ;
let mut fixed_str = String ::new ( ) ;
2025-12-25 21:08:37 -08:00
let matching_bracket : HashMap < char , char > = [ ( ')' , '(' ) , ( '}' , '{' ) , ( ']' , '[' ) ]
2025-11-22 12:55:00 -08:00
. iter ( )
2025-12-25 21:08:37 -08:00
. cloned ( )
2025-11-22 12:55:00 -08:00
. collect ( ) ;
2025-12-25 21:08:37 -08:00
let opening_bracket : HashMap < char , char > =
matching_bracket . iter ( ) . map ( | ( k , v ) | ( * v , * k ) ) . collect ( ) ;
2025-11-22 12:55:00 -08:00
for ch in json_str . chars ( ) {
if ch = = '{' | | ch = = '[' | | ch = = '(' {
stack . push ( ch ) ;
fixed_str . push ( ch ) ;
} else if ch = = '}' | | ch = = ']' | | ch = = ')' {
if let Some ( & last ) = stack . last ( ) {
if matching_bracket . get ( & ch ) = = Some ( & last ) {
stack . pop ( ) ;
fixed_str . push ( ch ) ;
}
// Ignore unmatched closing brackets
}
} else {
fixed_str . push ( ch ) ;
}
}
// Add corresponding closing brackets for unmatched opening brackets
while let Some ( unmatched_opening ) = stack . pop ( ) {
if let Some ( & closing ) = opening_bracket . get ( & unmatched_opening ) {
fixed_str . push ( closing ) ;
}
}
// Try to parse the fixed JSON
match serde_json ::from_str ::< Value > ( & fixed_str ) {
Ok ( val ) = > serde_json ::to_string ( & val ) . map_err ( FunctionCallingError ::from ) ,
Err ( _ ) = > {
// Try replacing single quotes with double quotes
let fixed_str = fixed_str . replace ( '\'' , " \" " ) ;
match serde_json ::from_str ::< Value > ( & fixed_str ) {
Ok ( val ) = > serde_json ::to_string ( & val ) . map_err ( FunctionCallingError ::from ) ,
Err ( e ) = > Err ( FunctionCallingError ::JsonFixError ( format! (
" Failed to fix JSON: {} " ,
e
) ) ) ,
}
}
}
}
/// Parses the model response and extracts tool call information
pub fn parse_model_response ( & self , content : & str ) -> ParsedModelResponse {
let mut response_dict = ParsedModelResponse ::default ( ) ;
// Remove markdown code blocks
let mut content = content . trim ( ) . to_string ( ) ;
if content . starts_with ( " ``` " ) & & content . ends_with ( " ``` " ) {
2025-12-25 21:08:37 -08:00
content = content
. trim_start_matches ( " ``` " )
. trim_end_matches ( " ``` " )
. to_string ( ) ;
2025-11-22 12:55:00 -08:00
if content . starts_with ( " json " ) {
content = content . trim_start_matches ( " json " ) . to_string ( ) ;
}
// Trim again after removing code blocks to eliminate internal whitespace
2025-12-25 21:08:37 -08:00
content = content
. trim_start_matches ( r "\n" )
. trim_end_matches ( r "\n" )
. to_string ( ) ;
2025-11-22 12:55:00 -08:00
content = content . trim ( ) . to_string ( ) ;
// Unescape the quotes: \" -> "
// The model sometimes returns escaped JSON inside markdown blocks
content = content . replace ( r # "\""# , " \" " ) ;
}
// Try to fix JSON if needed
let fixed_content = match self . fix_json_string ( & content ) {
Ok ( fixed ) = > {
response_dict . raw_response = format! ( " ```json \n {} \n ``` " , fixed ) ;
fixed
}
Err ( e ) = > {
response_dict . is_valid = false ;
response_dict . error_message = format! ( " Failed to fix JSON: {} " , e ) ;
return response_dict ;
}
} ;
// Parse the JSON
match serde_json ::from_str ::< Value > ( & fixed_content ) {
Ok ( model_response ) = > {
// Successfully parsed - mark as valid
response_dict . is_valid = true ;
// Extract response field
if let Some ( resp ) = model_response . get ( " response " ) {
if let Some ( resp_str ) = resp . as_str ( ) {
response_dict . response = Some ( resp_str . to_string ( ) ) ;
}
}
// Extract required_functions
if let Some ( funcs ) = model_response . get ( " required_functions " ) {
if let Some ( funcs_arr ) = funcs . as_array ( ) {
response_dict . required_functions = funcs_arr
. iter ( )
. filter_map ( | v | v . as_str ( ) . map ( String ::from ) )
. collect ( ) ;
}
}
// Extract clarification
if let Some ( clarif ) = model_response . get ( " clarification " ) {
if let Some ( clarif_str ) = clarif . as_str ( ) {
response_dict . clarification = clarif_str . to_string ( ) ;
}
}
// Extract tool_calls
if let Some ( tool_calls ) = model_response . get ( " tool_calls " ) {
if let Some ( tool_calls_arr ) = tool_calls . as_array ( ) {
for tool_call_val in tool_calls_arr {
let id = format! ( " call_ {} " , rand ::random ::< u32 > ( ) % 10000 + 1000 ) ;
let name = tool_call_val
. get ( " name " )
. and_then ( | v | v . as_str ( ) )
. unwrap_or ( " " )
. to_string ( ) ;
let arguments = tool_call_val
. get ( " arguments " )
. map ( | v | serde_json ::to_string ( v ) . unwrap_or_default ( ) )
. unwrap_or_default ( ) ;
response_dict . tool_calls . push ( ToolCall {
id ,
call_type : " function " . to_string ( ) ,
function : FunctionCall { name , arguments } ,
} ) ;
}
}
}
}
Err ( e ) = > {
response_dict . is_valid = false ;
response_dict . error_message = format! ( " Failed to parse model response: {} " , e ) ;
}
}
response_dict
}
/// Converts data type from one type to another
pub fn convert_data_type ( & self , value : & Value , target_type : & str ) -> Result < Value > {
match target_type {
// Handle float/number conversions
" float " | " number " = > {
if let Some ( int_val ) = value . as_i64 ( ) {
return Ok ( json! ( int_val as f64 ) ) ;
}
}
// Handle list/array conversions
" list " | " array " = > {
if let Some ( str_val ) = value . as_str ( ) {
// Try to parse as JSON array
if let Ok ( arr ) = serde_json ::from_str ::< Vec < Value > > ( str_val ) {
return Ok ( json! ( arr ) ) ;
}
}
}
// Handle str/string conversions
" str " | " string " = > {
if ! value . is_string ( ) {
return Ok ( json! ( value . to_string ( ) ) ) ;
}
}
_ = > { }
}
Ok ( value . clone ( ) )
}
/// Reports whether `value` already has the JSON shape that `target_type` names.
///
/// Type names without a dedicated rule are accepted unconditionally.
fn check_value_type(&self, value: &Value, target_type: &str) -> bool {
    let is_integer = value.is_i64() || value.is_u64();
    match target_type {
        " int " | " integer " => is_integer,
        // Integers are acceptable wherever a float is expected.
        " float " | " number " => value.is_f64() || is_integer,
        " bool " | " boolean " => value.is_boolean(),
        " str " | " string " => value.is_string(),
        " list " | " array " => value.is_array(),
        " dict " | " object " => value.is_object(),
        _ => true,
    }
}
/// Decides whether `param_value` is usable as a `target_type` parameter.
///
/// Returns `Ok(true)` when the value has the right type as-is or after conversion
/// with `convert_data_type`, and `Ok(false)` when it still mismatches after
/// conversion. Conversion errors are propagated.
fn validate_or_convert_parameter(&self, param_value: &Value, target_type: &str) -> Result<bool> {
    if !self.check_value_type(param_value, target_type) {
        // Not a direct match: judge the converted value instead.
        let coerced = self.convert_data_type(param_value, target_type)?;
        return Ok(self.check_value_type(&coerced, target_type));
    }
    Ok(true)
}
/// Verifies the validity of extracted tool calls against the provided tools
pub fn verify_tool_calls (
& self ,
tools : & [ Tool ] ,
tool_calls : & [ ToolCall ] ,
) -> ToolCallVerification {
let mut verification = ToolCallVerification ::default ( ) ;
// Build a map of function name to parameters
let mut functions : HashMap < String , & Value > = HashMap ::new ( ) ;
for tool in tools {
functions . insert ( tool . function . name . clone ( ) , & tool . function . parameters ) ;
}
for tool_call in tool_calls {
if ! verification . is_valid {
break ;
}
let func_name = & tool_call . function . name ;
// Parse arguments as JSON
2025-12-25 21:08:37 -08:00
let func_args : HashMap < String , Value > =
match serde_json ::from_str ( & tool_call . function . arguments ) {
Ok ( args ) = > args ,
Err ( e ) = > {
verification . is_valid = false ;
verification . invalid_tool_call = Some ( tool_call . clone ( ) ) ;
verification . error_message = format! (
" Failed to parse arguments for function '{}': {} " ,
func_name , e
) ;
break ;
}
} ;
2025-11-22 12:55:00 -08:00
// Check if function is available
if let Some ( function_params ) = functions . get ( func_name ) {
// Check if all required parameters are present
if let Some ( required ) = function_params . get ( " required " ) {
if let Some ( required_arr ) = required . as_array ( ) {
for required_param in required_arr {
if let Some ( param_name ) = required_param . as_str ( ) {
if ! func_args . contains_key ( param_name ) {
verification . is_valid = false ;
verification . invalid_tool_call = Some ( tool_call . clone ( ) ) ;
verification . error_message = format! (
" `{}` is required by the function `{}` but not found in the tool call! " ,
param_name , func_name
) ;
break ;
}
}
}
}
}
// Verify the data type of each parameter
if let Some ( properties ) = function_params . get ( " properties " ) {
if let Some ( properties_obj ) = properties . as_object ( ) {
for ( param_name , param_value ) in & func_args {
if let Some ( param_schema ) = properties_obj . get ( param_name ) {
2025-12-25 21:08:37 -08:00
if let Some ( target_type ) =
param_schema . get ( " type " ) . and_then ( | v | v . as_str ( ) )
{
if self
. config
. support_data_types
. contains ( & target_type . to_string ( ) )
{
2025-11-22 12:55:00 -08:00
// Validate data type using helper method
2025-12-25 21:08:37 -08:00
match self
. validate_or_convert_parameter ( param_value , target_type )
{
2025-11-22 12:55:00 -08:00
Ok ( is_valid ) = > {
if ! is_valid {
verification . is_valid = false ;
2025-12-25 21:08:37 -08:00
verification . invalid_tool_call =
Some ( tool_call . clone ( ) ) ;
2025-11-22 12:55:00 -08:00
verification . error_message = format! (
" Parameter `{}` is expected to have the data type `{}`, got incompatible type. " ,
param_name , target_type
) ;
break ;
}
}
Err ( _ ) = > {
verification . is_valid = false ;
2025-12-25 21:08:37 -08:00
verification . invalid_tool_call =
Some ( tool_call . clone ( ) ) ;
2025-11-22 12:55:00 -08:00
verification . error_message = format! (
" Parameter `{}` is expected to have the data type `{}`, got incompatible type. " ,
param_name , target_type
) ;
break ;
}
}
} else {
verification . is_valid = false ;
verification . invalid_tool_call = Some ( tool_call . clone ( ) ) ;
2025-12-25 21:08:37 -08:00
verification . error_message = format! (
" Data type `{}` is not supported. " ,
target_type
) ;
2025-11-22 12:55:00 -08:00
break ;
}
}
} else {
verification . is_valid = false ;
verification . invalid_tool_call = Some ( tool_call . clone ( ) ) ;
verification . error_message = format! (
" Parameter `{}` is not defined in the function `{}`. " ,
param_name , func_name
) ;
break ;
}
}
}
}
} else {
verification . is_valid = false ;
verification . invalid_tool_call = Some ( tool_call . clone ( ) ) ;
verification . error_message = format! ( " {} is not available! " , func_name ) ;
}
}
verification
}
/// Builds the system prompt: the task prompt with its tools placeholder replaced by
/// the serialized `tools`, followed by the format prompt.
///
/// # Errors
///
/// Propagates serialization failures from `convert_tools`.
pub fn format_system_prompt(&self, tools: &[Tool]) -> Result<String> {
    let serialized_tools = self.convert_tools(tools)?;
    let mut prompt = self
        .config
        .task_prompt
        .replace(" {tools} ", &serialized_tools);
    prompt.push_str(&self.config.format_prompt);
    Ok(prompt)
}
/// Processes messages and formats them appropriately for the model
pub fn process_messages (
& self ,
messages : & [ Message ] ,
tools : Option < & [ Tool ] > ,
extra_instruction : Option < & str > ,
max_tokens : usize ,
metadata : Option < & HashMap < String , Value > > ,
) -> Result < Vec < Message > > {
let mut processed_messages = Vec ::new ( ) ;
// Add system message with tools if provided
if let Some ( tools ) = tools {
let system_prompt = self . format_system_prompt ( tools ) ? ;
processed_messages . push ( Message {
role : Role ::System ,
2026-01-16 16:24:03 -08:00
content : Some ( MessageContent ::Text ( system_prompt ) ) ,
2025-11-22 12:55:00 -08:00
name : None ,
tool_calls : None ,
tool_call_id : None ,
} ) ;
}
// Process each message
for ( idx , message ) in messages . iter ( ) . enumerate ( ) {
let mut role = message . role . clone ( ) ;
let mut content = match & message . content {
2026-01-16 16:24:03 -08:00
Some ( MessageContent ::Text ( text ) ) = > text . clone ( ) ,
Some ( MessageContent ::Parts ( _ ) ) = > String ::new ( ) ,
None = > String ::new ( ) ,
2025-11-22 12:55:00 -08:00
} ;
// Handle tool calls
if let Some ( tool_calls ) = & message . tool_calls {
if ! tool_calls . is_empty ( ) {
role = Role ::Assistant ;
let tool_call_json = serde_json ::to_string ( & tool_calls [ 0 ] . function ) ? ;
content = format! ( " <tool_call> \n {} \n </tool_call> " , tool_call_json ) ;
}
} else if role = = Role ::Tool {
role = Role ::User ;
// Check if we should optimize context window
let optimize_context = metadata
. and_then ( | m | m . get ( " optimize_context_window " ) )
. and_then ( | v | v . as_str ( ) )
. map ( | s | s . to_lowercase ( ) = = " true " )
. unwrap_or ( false ) ;
if optimize_context {
content = " <tool_response> \n \n </tool_response> " . to_string ( ) ;
} else {
// Get the tool call from previous message
if idx > 0 {
2026-01-16 16:24:03 -08:00
if let Some ( MessageContent ::Text ( prev_content ) ) = & messages [ idx - 1 ] . content
{
2025-11-22 12:55:00 -08:00
let mut tool_call_msg = prev_content . clone ( ) ;
// Strip markdown code blocks
if tool_call_msg . starts_with ( " ``` " ) & & tool_call_msg . ends_with ( " ``` " ) {
2025-12-25 21:08:37 -08:00
tool_call_msg = tool_call_msg
. trim_start_matches ( " ``` " )
. trim_end_matches ( " ``` " )
. trim ( )
. to_string ( ) ;
2025-11-22 12:55:00 -08:00
if tool_call_msg . starts_with ( " json " ) {
2025-12-25 21:08:37 -08:00
tool_call_msg =
tool_call_msg . trim_start_matches ( " json " ) . trim ( ) . to_string ( ) ;
2025-11-22 12:55:00 -08:00
}
}
// Extract function name
if let Ok ( parsed ) = serde_json ::from_str ::< Value > ( & tool_call_msg ) {
2025-12-25 21:08:37 -08:00
if let Some ( tool_calls_arr ) =
parsed . get ( " tool_calls " ) . and_then ( | v | v . as_array ( ) )
{
2025-11-22 12:55:00 -08:00
if let Some ( first_tool_call ) = tool_calls_arr . first ( ) {
let func_name = first_tool_call
. get ( " name " )
. and_then ( | v | v . as_str ( ) )
. unwrap_or ( " no_name " ) ;
let tool_response = json! ( {
" name " : func_name ,
" result " : content ,
} ) ;
2025-12-25 21:08:37 -08:00
content = format! (
" <tool_response> \n {} \n </tool_response> " ,
serde_json ::to_string ( & tool_response ) ?
) ;
2025-11-22 12:55:00 -08:00
}
}
}
}
}
}
}
processed_messages . push ( Message {
role ,
2026-01-16 16:24:03 -08:00
content : Some ( MessageContent ::Text ( content ) ) ,
2025-11-22 12:55:00 -08:00
name : message . name . clone ( ) ,
tool_calls : None ,
tool_call_id : None ,
} ) ;
}
// Ensure last message is from user
if let Some ( last ) = processed_messages . last ( ) {
if last . role ! = Role ::User {
return Err ( FunctionCallingError ::InvalidModelResponse (
" Last message must be from user " . to_string ( ) ,
) ) ;
}
}
// Add extra instruction if provided
if let Some ( instruction ) = extra_instruction {
if let Some ( last ) = processed_messages . last_mut ( ) {
2026-01-16 16:24:03 -08:00
if let Some ( MessageContent ::Text ( content ) ) = & mut last . content {
2025-12-25 21:08:37 -08:00
content . push ( '\n' ) ;
2025-11-22 12:55:00 -08:00
content . push_str ( instruction ) ;
}
}
}
// Truncate messages if they exceed max_tokens
let processed_messages = self . truncate_messages ( processed_messages , max_tokens ) ;
Ok ( processed_messages )
}
/// Truncates messages to fit within max_tokens limit
fn truncate_messages ( & self , messages : Vec < Message > , max_tokens : usize ) -> Vec < Message > {
let mut num_tokens = 0 ;
let mut conversation_idx = 0 ;
// Keep system message if present
if let Some ( first ) = messages . first ( ) {
if first . role = = Role ::System {
2026-01-16 16:24:03 -08:00
if let Some ( MessageContent ::Text ( content ) ) = & first . content {
2025-11-22 12:55:00 -08:00
num_tokens + = content . len ( ) / 4 ; // Approximate 4 chars per token
}
conversation_idx = 1 ;
}
}
// Calculate from the end backwards
// Start with message_idx pointing past the end (will be used if no truncation needed)
let mut message_idx = messages . len ( ) ;
for i in ( conversation_idx .. messages . len ( ) ) . rev ( ) {
2026-01-16 16:24:03 -08:00
if let Some ( MessageContent ::Text ( content ) ) = & messages [ i ] . content {
2025-11-22 12:55:00 -08:00
num_tokens + = content . len ( ) / 4 ;
2025-12-25 21:08:37 -08:00
if num_tokens > = max_tokens & & messages [ i ] . role = = Role ::User {
// Set message_idx to current position and break
// This matches Python's behavior where message_idx is set before break
message_idx = i ;
break ;
2025-11-22 12:55:00 -08:00
}
}
// Only update message_idx if we haven't hit the token limit yet
// This ensures message_idx points to where truncation should start
if num_tokens < max_tokens {
message_idx = i ;
}
}
// Return system message + truncated conversation
let mut result = Vec ::new ( ) ;
if conversation_idx > 0 {
result . push ( messages [ 0 ] . clone ( ) ) ;
}
result . extend_from_slice ( & messages [ message_idx .. ] ) ;
result
}
/// Prefills a message by adding an assistant message with the prefix
pub fn prefill_message ( & self , mut messages : Vec < Message > , prefill : & str ) -> Vec < Message > {
messages . push ( Message {
role : Role ::Assistant ,
2026-01-16 16:24:03 -08:00
content : Some ( MessageContent ::Text ( prefill . to_string ( ) ) ) ,
2025-11-22 12:55:00 -08:00
name : None ,
tool_calls : None ,
tool_call_id : None ,
} ) ;
messages
}
/// Helper to create a request with VLLM-specific parameters
2025-12-25 21:08:37 -08:00
fn create_request_with_extra_body (
& self ,
messages : Vec < Message > ,
stream : bool ,
) -> ChatCompletionsRequest {
2025-11-22 12:55:00 -08:00
ChatCompletionsRequest {
model : self . model_name . clone ( ) ,
messages ,
temperature : Some ( self . config . generation_params . temperature ) ,
top_p : Some ( self . config . generation_params . top_p ) ,
max_tokens : Some ( self . config . generation_params . max_tokens ) ,
stream : Some ( stream ) ,
logprobs : self . config . generation_params . logprobs ,
top_logprobs : self . config . generation_params . top_logprobs ,
// VLLM-specific parameters
continue_final_message : Some ( true ) ,
add_generation_prompt : Some ( false ) ,
top_k : Some ( self . config . generation_params . top_k ) ,
stop_token_ids : if ! self . config . generation_params . stop_token_ids . is_empty ( ) {
Some ( self . config . generation_params . stop_token_ids . clone ( ) )
} else {
None
} ,
.. Default ::default ( )
}
}
/// Makes a streaming request and returns the SSE event stream
2025-12-25 21:08:37 -08:00
async fn make_streaming_request (
& self ,
request : ChatCompletionsRequest ,
) -> Result <
std ::pin ::Pin < Box < dyn futures ::Stream < Item = std ::result ::Result < Value , String > > + Send > > ,
> {
let request_body = serde_json ::to_string ( & request ) . map_err ( | e | {
FunctionCallingError ::InvalidModelResponse ( format! (
" Failed to serialize request: {} " ,
e
) )
} ) ? ;
let response = self
. http_client
2025-11-22 12:55:00 -08:00
. post ( & self . endpoint_url )
. header ( " Content-Type " , " application/json " )
. body ( request_body )
. send ( )
. await
2025-12-25 21:08:37 -08:00
. map_err ( FunctionCallingError ::HttpError ) ? ;
2025-11-22 12:55:00 -08:00
if ! response . status ( ) . is_success ( ) {
let status = response . status ( ) ;
2025-12-25 21:08:37 -08:00
let error_text = response
. text ( )
. await
. unwrap_or_else ( | _ | " Unknown error " . to_string ( ) ) ;
return Err ( FunctionCallingError ::InvalidModelResponse ( format! (
" HTTP error {}: {} " ,
status , error_text
) ) ) ;
2025-11-22 12:55:00 -08:00
}
// Parse SSE stream
let stream = response . bytes_stream ( ) . eventsource ( ) ;
let parsed_stream = stream . filter_map ( | event_result | async move {
match event_result {
Ok ( event ) = > {
// Skip [DONE] sentinel
if event . data = = " [DONE] " {
return None ;
}
// Parse JSON
match serde_json ::from_str ::< Value > ( & event . data ) {
Ok ( json ) = > Some ( Ok ( json ) ) ,
Err ( e ) = > Some ( Err ( format! ( " JSON parse error: {} " , e ) ) ) ,
}
}
Err ( e ) = > Some ( Err ( format! ( " SSE stream error: {} " , e ) ) ) ,
}
} ) ;
Ok ( Box ::pin ( parsed_stream ) )
}
/// Makes a non-streaming request and returns the response
2025-12-25 21:08:37 -08:00
async fn make_non_streaming_request (
& self ,
request : ChatCompletionsRequest ,
) -> Result < ChatCompletionsResponse > {
let request_body = serde_json ::to_string ( & request ) . map_err ( | e | {
FunctionCallingError ::InvalidModelResponse ( format! (
" Failed to serialize request: {} " ,
e
) )
} ) ? ;
let response = self
. http_client
2025-11-22 12:55:00 -08:00
. post ( & self . endpoint_url )
. header ( " Content-Type " , " application/json " )
. body ( request_body )
. send ( )
. await
2025-12-25 21:08:37 -08:00
. map_err ( FunctionCallingError ::HttpError ) ? ;
2025-11-22 12:55:00 -08:00
if ! response . status ( ) . is_success ( ) {
let status = response . status ( ) ;
2025-12-25 21:08:37 -08:00
let error_text = response
. text ( )
. await
. unwrap_or_else ( | _ | " Unknown error " . to_string ( ) ) ;
return Err ( FunctionCallingError ::InvalidModelResponse ( format! (
" HTTP error {}: {} " ,
status , error_text
) ) ) ;
2025-11-22 12:55:00 -08:00
}
2025-12-25 21:08:37 -08:00
let response_text = response
. text ( )
. await
. map_err ( FunctionCallingError ::HttpError ) ? ;
2025-11-22 12:55:00 -08:00
2025-12-25 21:08:37 -08:00
serde_json ::from_str ( & response_text ) . map_err ( FunctionCallingError ::JsonParseError )
2025-11-22 12:55:00 -08:00
}
pub async fn function_calling_chat (
& self ,
request : ChatCompletionsRequest ,
) -> Result < ChatCompletionsResponse > {
2025-12-25 21:08:37 -08:00
use tracing ::{ error , info } ;
2025-11-22 12:55:00 -08:00
info! ( " [Arch-Function] - ChatCompletion " ) ;
let messages = self . process_messages (
& request . messages ,
request . tools . as_deref ( ) ,
None ,
self . config . generation_params . max_tokens as usize ,
request . metadata . as_ref ( ) ,
) ? ;
2025-12-25 21:08:37 -08:00
info! (
" [request to arch-fc]: model: {}, messages count: {} " ,
self . model_name ,
messages . len ( )
) ;
2025-11-22 12:55:00 -08:00
2025-12-25 21:08:37 -08:00
let use_agent_orchestrator = request
. metadata
2025-11-22 12:55:00 -08:00
. as_ref ( )
. and_then ( | m | m . get ( " use_agent_orchestrator " ) )
. and_then ( | v | v . as_bool ( ) )
. unwrap_or ( false ) ;
let prefilled_messages = self . prefill_message ( messages . clone ( ) , & self . default_prefix ) ;
// Create request with extra_body parameters
let stream_request = self . create_request_with_extra_body ( prefilled_messages . clone ( ) , true ) ;
let mut stream = self . make_streaming_request ( stream_request ) . await ? ;
let mut model_response = String ::new ( ) ;
if use_agent_orchestrator {
while let Some ( chunk_result ) = stream . next ( ) . await {
2025-12-25 21:08:37 -08:00
let chunk = chunk_result . map_err ( FunctionCallingError ::InvalidModelResponse ) ? ;
2025-11-22 12:55:00 -08:00
// Extract content from JSON response
if let Some ( choices ) = chunk . get ( " choices " ) . and_then ( | v | v . as_array ( ) ) {
if let Some ( choice ) = choices . first ( ) {
2025-12-25 21:08:37 -08:00
if let Some ( content ) = choice
. get ( " delta " )
2025-11-22 12:55:00 -08:00
. and_then ( | d | d . get ( " content " ) )
2025-12-25 21:08:37 -08:00
. and_then ( | c | c . as_str ( ) )
{
2025-11-22 12:55:00 -08:00
model_response . push_str ( content ) ;
}
}
}
}
info! ( " [Agent Orchestrator]: response received " ) ;
2025-12-25 21:08:37 -08:00
} else if let Some ( tools ) = request . tools . as_ref ( ) {
let mut hallucination_state = HallucinationState ::new ( tools ) ;
let mut has_tool_calls = None ;
let mut has_hallucination = false ;
2025-11-22 12:55:00 -08:00
2025-12-25 21:08:37 -08:00
while let Some ( chunk_result ) = stream . next ( ) . await {
let chunk = chunk_result . map_err ( FunctionCallingError ::InvalidModelResponse ) ? ;
// Extract content and logprobs from JSON response
if let Some ( choices ) = chunk . get ( " choices " ) . and_then ( | v | v . as_array ( ) ) {
if let Some ( choice ) = choices . first ( ) {
if let Some ( content ) = choice
. get ( " delta " )
. and_then ( | d | d . get ( " content " ) )
. and_then ( | c | c . as_str ( ) )
{
// Extract logprobs
let logprobs : Vec < f64 > = choice
. get ( " logprobs " )
. and_then ( | lp | lp . get ( " content " ) )
. and_then ( | c | c . as_array ( ) )
. and_then ( | arr | arr . first ( ) )
. and_then ( | token | token . get ( " top_logprobs " ) )
. and_then ( | tlp | tlp . as_array ( ) )
. map ( | arr | {
arr . iter ( )
. filter_map ( | v | v . get ( " logprob " ) . and_then ( | lp | lp . as_f64 ( ) ) )
. collect ( )
} )
. unwrap_or_default ( ) ;
if hallucination_state
. append_and_check_token_hallucination ( content . to_string ( ) , logprobs )
{
has_hallucination = true ;
break ;
}
if hallucination_state . tokens . len ( ) > 5 & & has_tool_calls . is_none ( ) {
let collected_content = hallucination_state . tokens . join ( " " ) ;
has_tool_calls = Some ( collected_content . contains ( " tool_calls " ) ) ;
2025-11-22 12:55:00 -08:00
}
}
}
}
2025-12-25 21:08:37 -08:00
}
2025-11-22 12:55:00 -08:00
2025-12-25 21:08:37 -08:00
if has_tool_calls = = Some ( true ) & & has_hallucination {
info! ( " [Hallucination]: {} " , hallucination_state . error_message ) ;
2025-11-22 12:55:00 -08:00
2025-12-25 21:08:37 -08:00
let clarify_messages = self . prefill_message ( messages . clone ( ) , & self . clarify_prefix ) ;
let clarify_request = self . create_request_with_extra_body ( clarify_messages , false ) ;
2025-11-22 12:55:00 -08:00
2025-12-25 21:08:37 -08:00
let retry_response = self . make_non_streaming_request ( clarify_request ) . await ? ;
2025-11-22 12:55:00 -08:00
2025-12-25 21:08:37 -08:00
if let Some ( choice ) = retry_response . choices . first ( ) {
if let Some ( content ) = & choice . message . content {
model_response = content . clone ( ) ;
2025-11-22 12:55:00 -08:00
}
}
} else {
2025-12-25 21:08:37 -08:00
model_response = hallucination_state . tokens . join ( " " ) ;
}
} else {
while let Some ( chunk_result ) = stream . next ( ) . await {
let chunk = chunk_result . map_err ( FunctionCallingError ::InvalidModelResponse ) ? ;
if let Some ( choices ) = chunk . get ( " choices " ) . and_then ( | v | v . as_array ( ) ) {
if let Some ( choice ) = choices . first ( ) {
if let Some ( content ) = choice
. get ( " delta " )
. and_then ( | d | d . get ( " content " ) )
. and_then ( | c | c . as_str ( ) )
{
model_response . push_str ( content ) ;
2025-11-22 12:55:00 -08:00
}
}
}
}
}
let response_dict = self . parse_model_response ( & model_response ) ;
2025-12-25 21:08:37 -08:00
info! (
" [arch-fc]: raw model response: {} " ,
response_dict . raw_response
) ;
2025-11-22 12:55:00 -08:00
// General model response (no intent matched - should route to default target)
2025-12-25 21:08:37 -08:00
let model_message = if response_dict
. response
. as_ref ( )
. is_some_and ( | s | ! s . is_empty ( ) )
{
2025-11-22 12:55:00 -08:00
// When arch-fc returns a "response" field, it means no intent was matched
// Return empty content and empty tool_calls so prompt_gateway routes to default target
ResponseMessage {
role : Role ::Assistant ,
content : Some ( String ::new ( ) ) ,
refusal : None ,
annotations : None ,
audio : None ,
function_call : None ,
tool_calls : None ,
}
} else if ! response_dict . required_functions . is_empty ( ) {
if ! use_agent_orchestrator {
ResponseMessage {
role : Role ::Assistant ,
content : Some ( response_dict . clarification . clone ( ) ) ,
refusal : None ,
annotations : None ,
audio : None ,
function_call : None ,
tool_calls : None ,
}
} else {
ResponseMessage {
role : Role ::Assistant ,
content : Some ( String ::new ( ) ) ,
refusal : None ,
annotations : None ,
audio : None ,
function_call : None ,
tool_calls : None ,
}
}
} else if ! response_dict . tool_calls . is_empty ( ) {
if response_dict . is_valid {
if ! use_agent_orchestrator {
if let Some ( tools ) = request . tools . as_ref ( ) {
let verification = self . verify_tool_calls ( tools , & response_dict . tool_calls ) ;
if verification . is_valid {
2025-12-25 21:08:37 -08:00
info! (
" [Tool calls]: {:?} " ,
response_dict
. tool_calls
. iter ( )
2025-11-22 12:55:00 -08:00
. map ( | tc | & tc . function )
. collect ::< Vec < _ > > ( )
) ;
ResponseMessage {
role : Role ::Assistant ,
content : Some ( String ::new ( ) ) ,
refusal : None ,
annotations : None ,
audio : None ,
function_call : None ,
tool_calls : Some ( response_dict . tool_calls . clone ( ) ) ,
}
} else {
error! ( " Invalid tool call - {} " , verification . error_message ) ;
ResponseMessage {
role : Role ::Assistant ,
content : Some ( String ::new ( ) ) ,
refusal : None ,
annotations : None ,
audio : None ,
function_call : None ,
tool_calls : None ,
}
}
} else {
error! ( " Tool calls present but no tools provided in request " ) ;
ResponseMessage {
role : Role ::Assistant ,
content : Some ( String ::new ( ) ) ,
refusal : None ,
annotations : None ,
audio : None ,
function_call : None ,
tool_calls : None ,
}
}
} else {
2025-12-25 21:08:37 -08:00
info! (
" [Tool calls]: {:?} " ,
response_dict
. tool_calls
. iter ( )
2025-11-22 12:55:00 -08:00
. map ( | tc | & tc . function )
. collect ::< Vec < _ > > ( )
) ;
ResponseMessage {
role : Role ::Assistant ,
content : Some ( String ::new ( ) ) ,
refusal : None ,
annotations : None ,
audio : None ,
function_call : None ,
tool_calls : Some ( response_dict . tool_calls . clone ( ) ) ,
}
}
} else {
2025-12-25 21:08:37 -08:00
error! (
" Invalid tool calls in response: {} " ,
response_dict . error_message
) ;
2025-11-22 12:55:00 -08:00
ResponseMessage {
role : Role ::Assistant ,
content : Some ( String ::new ( ) ) ,
refusal : None ,
annotations : None ,
audio : None ,
function_call : None ,
tool_calls : None ,
}
}
} else {
error! ( " Invalid model response - {} " , model_response ) ;
ResponseMessage {
role : Role ::Assistant ,
content : Some ( String ::new ( ) ) ,
refusal : None ,
annotations : None ,
audio : None ,
function_call : None ,
tool_calls : None ,
}
} ;
// Create metadata with the raw model response
let mut metadata = HashMap ::new ( ) ;
metadata . insert (
" x-arch-fc-model-response " . to_string ( ) ,
serde_json ::to_value ( & response_dict . raw_response )
. unwrap_or_else ( | _ | Value ::String ( response_dict . raw_response . clone ( ) ) ) ,
) ;
let chat_completion_response = ChatCompletionsResponse {
id : format ! ( " chatcmpl-{} " , uuid ::Uuid ::new_v4 ( ) ) ,
object : Some ( " chat.completion " . to_string ( ) ) ,
created : chrono ::Utc ::now ( ) . timestamp ( ) as u64 ,
model : request . model . clone ( ) ,
choices : vec ! [ Choice {
index : 0 ,
message : model_message ,
finish_reason : Some ( FinishReason ::Stop ) ,
logprobs : None ,
} ] ,
usage : Usage {
prompt_tokens : 0 ,
completion_tokens : 0 ,
total_tokens : 0 ,
prompt_tokens_details : None ,
completion_tokens_details : None ,
} ,
system_fingerprint : None ,
service_tier : None ,
metadata : Some ( metadata ) ,
} ;
info! ( " [response arch-fc]: {:?} " , chat_completion_response ) ;
Ok ( chat_completion_response )
}
}
// ============================================================================
// ARCH AGENT HANDLER
// ============================================================================
/// Handler for Arch Agent (extends ArchFunctionHandler with specialized behavior)
///
/// The handler does not re-implement the chat flow: it owns an
/// `ArchFunctionHandler` that `ArchAgentHandler::new` configures from
/// `ArchAgentConfig::default()`, and callers reach the shared logic through
/// the public `function_handler` field.
pub struct ArchAgentHandler {
/// Wrapped function-calling handler, configured with the Arch Agent prompts
/// and generation parameters.
pub function_handler : ArchFunctionHandler ,
}
impl ArchAgentHandler {
/// Creates a new ArchAgentHandler
pub fn new ( model_name : String , endpoint_url : String ) -> Self {
let config = ArchAgentConfig ::default ( ) ;
Self {
function_handler : ArchFunctionHandler ::new (
model_name ,
ArchFunctionConfig {
task_prompt : config . task_prompt ,
format_prompt : config . format_prompt ,
generation_params : GenerationParams {
temperature : config . generation_params . temperature ,
top_p : config . generation_params . top_p ,
top_k : config . generation_params . top_k ,
max_tokens : config . generation_params . max_tokens ,
stop_token_ids : config . generation_params . stop_token_ids ,
logprobs : config . generation_params . logprobs ,
top_logprobs : config . generation_params . top_logprobs ,
} ,
support_data_types : config . support_data_types ,
} ,
endpoint_url ,
) ,
}
}
/// Converts tools with special handling for empty parameters
/// This is the key difference from ArchFunctionHandler
pub fn convert_tools ( & self , tools : & [ Tool ] ) -> Result < String > {
let mut converted = Vec ::new ( ) ;
for tool in tools {
let mut tool_copy = tool . clone ( ) ;
// Delete parameters key if its empty
if let Some ( props ) = tool_copy . function . parameters . get ( " properties " ) {
if props . is_object ( ) & & props . as_object ( ) . unwrap ( ) . is_empty ( ) {
// Create new parameters without properties
if let Some ( params_obj ) = tool_copy . function . parameters . as_object_mut ( ) {
params_obj . remove ( " properties " ) ;
}
}
}
converted . push ( serde_json ::to_string ( & tool_copy . function ) ? ) ;
}
Ok ( converted . join ( " \n " ) )
}
}
// ============================================================================
// HTTP HANDLER FOR FUNCTION CALLING ENDPOINT
// ============================================================================
fn full < T : Into < Bytes > > ( chunk : T ) -> BoxBody < Bytes , hyper ::Error > {
Full ::new ( chunk . into ( ) )
. map_err ( | never | match never { } )
. boxed ( )
}
pub async fn function_calling_chat_handler (
req : Request < Incoming > ,
llm_provider_url : String ,
) -> std ::result ::Result < Response < BoxBody < Bytes , hyper ::Error > > , hyper ::Error > {
use hermesllm ::apis ::openai ::ChatCompletionsRequest ;
let whole_body = req . collect ( ) . await ? . to_bytes ( ) ;
// Parse as JSON Value first to modify it
let mut body_json : Value = match serde_json ::from_slice ( & whole_body ) {
Ok ( json ) = > json ,
Err ( e ) = > {
error! ( " Failed to parse request body as JSON: {} " , e ) ;
let mut response = Response ::new ( full (
serde_json ::json! ( {
" error " : format ! ( " Invalid request body: {} " , e )
2025-12-25 21:08:37 -08:00
} )
. to_string ( ) ,
2025-11-22 12:55:00 -08:00
) ) ;
* response . status_mut ( ) = StatusCode ::BAD_REQUEST ;
2025-12-25 21:08:37 -08:00
response
. headers_mut ( )
. insert ( " Content-Type " , " application/json " . parse ( ) . unwrap ( ) ) ;
2025-11-22 12:55:00 -08:00
return Ok ( response ) ;
}
} ;
// Add "model": "Arch-Function" to the request
if let Some ( obj ) = body_json . as_object_mut ( ) {
obj . insert ( " model " . to_string ( ) , ARCH_FUNCTION_MODEL_NAME . into ( ) ) ;
}
// Parse as ChatCompletionsRequest
let chat_request : ChatCompletionsRequest = match serde_json ::from_value ( body_json ) {
Ok ( req ) = > {
2025-12-25 21:08:37 -08:00
info! (
" [request body]: {} " ,
serde_json ::to_string ( & req ) . unwrap_or_default ( )
) ;
2025-11-22 12:55:00 -08:00
req
2025-12-25 21:08:37 -08:00
}
2025-11-22 12:55:00 -08:00
Err ( e ) = > {
error! ( " Failed to parse request body: {} " , e ) ;
let mut response = Response ::new ( full (
serde_json ::json! ( {
" error " : format ! ( " Invalid request body: {} " , e )
2025-12-25 21:08:37 -08:00
} )
. to_string ( ) ,
2025-11-22 12:55:00 -08:00
) ) ;
* response . status_mut ( ) = StatusCode ::BAD_REQUEST ;
2025-12-25 21:08:37 -08:00
response
. headers_mut ( )
. insert ( " Content-Type " , " application/json " . parse ( ) . unwrap ( ) ) ;
2025-11-22 12:55:00 -08:00
return Ok ( response ) ;
}
} ;
// Determine which handler to use based on metadata
2025-12-25 21:08:37 -08:00
let use_agent_orchestrator = chat_request
. metadata
2025-11-22 12:55:00 -08:00
. as_ref ( )
. and_then ( | m | m . get ( " use_agent_orchestrator " ) )
. and_then ( | v | v . as_bool ( ) )
. unwrap_or ( false ) ;
info! ( " Use agent orchestrator: {} " , use_agent_orchestrator ) ;
// Create the appropriate handler
let handler_name = if use_agent_orchestrator {
" Arch-Agent "
} else {
" Arch-Function "
} ;
// Call the handler
let final_response = if use_agent_orchestrator {
let handler = ArchAgentHandler ::new (
ARCH_FUNCTION_MODEL_NAME . to_string ( ) ,
llm_provider_url . clone ( ) ,
) ;
2025-12-25 21:08:37 -08:00
handler
. function_handler
. function_calling_chat ( chat_request )
. await
2025-11-22 12:55:00 -08:00
} else {
let handler = ArchFunctionHandler ::new (
ARCH_FUNCTION_MODEL_NAME . to_string ( ) ,
ArchFunctionConfig ::default ( ) ,
llm_provider_url . clone ( ) ,
) ;
handler . function_calling_chat ( chat_request ) . await
} ;
match final_response {
Ok ( response_data ) = > {
let response_json = serde_json ::to_string ( & response_data ) . unwrap_or_else ( | e | {
error! ( " Failed to serialize response: {} " , e ) ;
serde_json ::json! ( { " error " : " Failed to serialize response " } ) . to_string ( )
} ) ;
let mut response = Response ::new ( full ( response_json ) ) ;
* response . status_mut ( ) = StatusCode ::OK ;
2025-12-25 21:08:37 -08:00
response
. headers_mut ( )
. insert ( " Content-Type " , " application/json " . parse ( ) . unwrap ( ) ) ;
2025-11-22 12:55:00 -08:00
Ok ( response )
}
Err ( e ) = > {
error! ( " [{}] - Error in function calling: {} " , handler_name , e ) ;
let error_response = serde_json ::json! ( {
" error " : format ! ( " [{}] - Error in function calling: {} " , handler_name , e )
} ) ;
let mut response = Response ::new ( full ( error_response . to_string ( ) ) ) ;
* response . status_mut ( ) = StatusCode ::INTERNAL_SERVER_ERROR ;
2025-12-25 21:08:37 -08:00
response
. headers_mut ( )
. insert ( " Content-Type " , " application/json " . parse ( ) . unwrap ( ) ) ;
2025-11-22 12:55:00 -08:00
Ok ( response )
}
}
}
// ============================================================================
// TESTS
// ============================================================================
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the `ArchFunctionHandler` fixture shared by the handler tests below.
    fn test_handler() -> ArchFunctionHandler {
        ArchFunctionHandler::new(
            "test-model".to_string(),
            ArchFunctionConfig::default(),
            "http://localhost:8000".to_string(),
        )
    }

    #[test]
    fn test_arch_function_config_default() {
        let config = ArchFunctionConfig::default();
        let task_prompt = &config.task_prompt;
        let format_prompt = &config.format_prompt;

        assert!(task_prompt.contains("helpful assistant"));
        assert!(format_prompt.contains("JSON formats"));
        assert_eq!(config.generation_params.temperature, 0.1);
        // 8 Python-style + 6 JSON Schema names
        assert_eq!(config.support_data_types.len(), 14);

        // The prompts are expected to contain literal escaped newlines ("\\n")
        // rather than actual newline characters.
        assert!(task_prompt.contains("\\n\\nYou are provided"));
        assert!(task_prompt.contains("</tools>\\n\\n"));

        // The format prompt carries the same literal escapes plus the JSON examples.
        assert!(format_prompt.contains("\\n\\nBased on your analysis"));
        assert!(format_prompt.contains(r#"{\"response\": \"Your response text here\"}"#));
        assert!(format_prompt.contains(r#"{\"tool_calls\": [{"#));
    }

    #[test]
    fn test_arch_agent_config_default() {
        let config = ArchAgentConfig::default();
        // Different from ArchFunctionConfig
        assert_eq!(config.generation_params.temperature, 0.01);
    }

    #[test]
    fn test_fix_json_string_valid() {
        let well_formed = r#"{"name": "test", "value": 123}"#;
        let outcome = test_handler().fix_json_string(well_formed);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_fix_json_string_missing_bracket() {
        let truncated = r#"{"name": "test", "value": 123"#;
        let outcome = test_handler().fix_json_string(truncated);
        assert!(outcome.is_ok());
        let repaired = outcome.unwrap();
        assert!(repaired.contains("}"));
    }

    #[test]
    fn test_parse_model_response_with_tool_calls() {
        let model_output =
            r#"{"tool_calls": [{"name": "get_weather", "arguments": {"location": "NYC"}}]}"#;
        let parsed = test_handler().parse_model_response(model_output);
        assert!(parsed.is_valid);
        assert_eq!(parsed.tool_calls.len(), 1);
        assert_eq!(parsed.tool_calls[0].function.name, "get_weather");
    }

    #[test]
    fn test_parse_model_response_with_clarification() {
        let model_output =
            r#"{"required_functions": ["get_weather"], "clarification": "What location?"}"#;
        let parsed = test_handler().parse_model_response(model_output);
        assert!(parsed.is_valid);
        assert_eq!(parsed.required_functions.len(), 1);
        assert_eq!(parsed.clarification, "What location?");
    }

    #[test]
    fn test_convert_data_type_int_to_float() {
        let integer = json!(42);
        let outcome = test_handler().convert_data_type(&integer, "float");
        assert!(outcome.is_ok());
        assert!(outcome.unwrap().is_f64());
    }
}
// ============================================================================
// HALLUCINATION DETECTION MODULE
// ============================================================================
/// Mask token types for tracking parsing state
///
/// `HallucinationState::process_token` records one of these labels per
/// streamed token in `HallucinationState::mask`, describing which part of the
/// tool-call JSON the token was attributed to.
#[ derive(Debug, Clone, Copy, PartialEq, Eq) ]
pub enum MaskToken {
/// Token seen while the parser state is "function_name" and the token is not
/// one of the function-name end tokens.
FunctionName ,
/// Token seen inside a parameter value that contains at least one character
/// other than ASCII punctuation.
ParameterValue ,
/// Token seen while the parser state is "parameter_name" and the collected
/// content does not yet end with a parameter-name end token.
ParameterName ,
/// Token that belongs to none of the tracked parts; also used to pad the
/// mask when no other label was recorded for a token.
NotUsed ,
/// NOTE(review): not assigned anywhere in the visible code - confirm whether
/// it is still needed.
ToolCall ,
}
/// Uncertainty metrics calculated from log probabilities
#[derive(Debug, Clone)]
pub struct UncertaintyMetrics {
    /// Entropy of the supplied distribution, in bits.
    pub entropy: f64,
    /// Probability-weighted squared deviation of each term's information
    /// content from the entropy, in bits squared.
    pub varentropy: f64,
    /// Probability of the first supplied log-probability.
    pub probability: f64,
}

/// Calculates uncertainty metrics from log probabilities
///
/// This is a simplified Rust implementation that avoids torch/tensor dependencies.
/// Uses basic statistical calculations instead of tensor operations.
///
/// With `p_i = exp(log_probs[i])`:
/// * `entropy    = -sum(p_i * ln p_i) / ln 2`
/// * `varentropy = sum(p_i * (log2 p_i + entropy)^2)`
/// * `probability = p_0` (the first entry; presumably the sampled/top token -
///   verify against the caller's ordering)
///
/// Entries whose probability is exactly zero (for example a log-probability
/// of negative infinity) contribute nothing to either sum, following the
/// usual `0 * log 0 = 0` convention. Without that rule a single such entry
/// would turn both metrics into NaN.
///
/// An empty slice yields all-zero metrics.
pub fn calculate_uncertainty(log_probs: &[f64]) -> UncertaintyMetrics {
    if log_probs.is_empty() {
        return UncertaintyMetrics {
            entropy: 0.0,
            varentropy: 0.0,
            probability: 0.0,
        };
    }

    let ln_2 = 2_f64.ln();

    // Pair every log-probability with its probability once, up front.
    let terms: Vec<(f64, f64)> = log_probs.iter().map(|&lp| (lp, lp.exp())).collect();

    // `p != 0.0` (rather than `p > 0.0`) keeps NaN inputs visible in the result.
    let contributing = || terms.iter().filter(|&&(_, p)| p != 0.0);

    // Entropy: -sum(p * ln p), converted to bits.
    let entropy = contributing().fold(0.0, |acc, &(lp, p)| acc - lp * p) / ln_2;

    // Variance of the information content around the entropy.
    let varentropy = contributing().fold(0.0, |acc, &(lp, p)| {
        let diff = lp / ln_2 + entropy;
        acc + p * diff * diff
    });

    // Probability of the first entry.
    let probability = terms.first().map_or(0.0, |&(_, p)| p);

    UncertaintyMetrics {
        entropy,
        varentropy,
        probability,
    }
}
/// Checks if uncertainty metrics exceed thresholds
pub fn check_threshold (
entropy : f64 ,
varentropy : f64 ,
thresholds : & HallucinationThresholds ,
) -> bool {
entropy > thresholds . entropy & & varentropy > thresholds . varentropy
}
/// Checks if a parameter is required in the function description
2025-12-25 21:08:37 -08:00
pub fn is_parameter_required ( function_description : & Value , parameter_name : & str ) -> bool {
2025-11-22 12:55:00 -08:00
if let Some ( required ) = function_description . get ( " required " ) {
if let Some ( required_arr ) = required . as_array ( ) {
2025-12-25 21:08:37 -08:00
return required_arr
. iter ( )
. any ( | v | v . as_str ( ) = = Some ( parameter_name ) ) ;
2025-11-22 12:55:00 -08:00
}
}
false
}
/// Tells whether the description of `parameter_name` (under `properties`)
/// contains the key `property_name`.
///
/// Returns `false` when `properties`, the parameter, or the key is missing.
pub fn is_parameter_property(
    function_description: &Value,
    parameter_name: &str,
    property_name: &str,
) -> bool {
    function_description
        .get("properties")
        .and_then(|properties| properties.get(parameter_name))
        .and_then(|param_info| param_info.get(property_name))
        .is_some()
}
/// State for hallucination detection during streaming
///
/// This is a simplified version of the Python HallucinationState that doesn't
/// require torch/tensor dependencies. It provides the core functionality needed
/// for detecting hallucinations during function calling.
///
/// Tokens are appended one at a time through
/// `append_and_check_token_hallucination`; after each token the parser state
/// below is updated and `hallucination` reflects the current verdict.
#[ derive(Debug) ]
pub struct HallucinationState {
/// Every token appended so far, in arrival order.
pub tokens : Vec < String > ,
/// Log-probabilities supplied with each token (one vector per token).
pub logprobs : Vec < Vec < f64 > > ,
/// Current parsing phase: `None`, or one of "function_name",
/// "parameter_name", "parameter_value".
pub state : Option < String > ,
/// Label recorded for each processed token; padded with `NotUsed` when no
/// other label was recorded for a token.
pub mask : Vec < MaskToken > ,
/// Set once a parameter name has been completed; cleared when a tool call ends.
pub parameter_name_done : bool ,
/// `true` once a checked token exceeded the uncertainty thresholds; cleared
/// when a tool call ends.
pub hallucination : bool ,
/// Description of the most recent problem found (uncertain token or unknown
/// function name).
pub error_message : String ,
/// Parameter names extracted so far; the last entry is the parameter whose
/// value is currently being read.
pub parameter_name : Vec < String > ,
/// `(token, entropy, varentropy, probability)` for every token whose
/// log-probabilities were evaluated.
pub token_probs_map : Vec < ( String , f64 , f64 , f64 ) > ,
/// Parameter description of each known function, keyed by function name.
pub function_properties : HashMap < String , Value > ,
/// Whether an opening bracket seen inside a parameter value is still unclosed.
pub open_bracket : bool ,
/// The opening bracket character being tracked while `open_bracket` is set.
pub bracket : Option < char > ,
/// Name of the function extracted from the streamed tool call.
pub function_name : String ,
/// Parameter names whose value has already been checked (the stored value
/// is always `true`); cleared when a tool call ends.
pub check_parameter_name : HashMap < String , bool > ,
/// Entropy/varentropy limits used to flag a token as uncertain.
pub thresholds : HallucinationThresholds ,
}
impl HallucinationState {
/// Creates a new HallucinationState with function definitions
pub fn new ( functions : & [ Tool ] ) -> Self {
let function_properties : HashMap < String , Value > = functions
. iter ( )
2025-12-25 21:08:37 -08:00
. map ( | tool | ( tool . function . name . clone ( ) , tool . function . parameters . clone ( ) ) )
2025-11-22 12:55:00 -08:00
. collect ( ) ;
Self {
tokens : Vec ::new ( ) ,
logprobs : Vec ::new ( ) ,
state : None ,
mask : Vec ::new ( ) ,
parameter_name_done : false ,
hallucination : false ,
error_message : String ::new ( ) ,
parameter_name : Vec ::new ( ) ,
token_probs_map : Vec ::new ( ) ,
function_properties ,
open_bracket : false ,
bracket : None ,
function_name : String ::new ( ) ,
check_parameter_name : HashMap ::new ( ) ,
thresholds : HallucinationThresholds ::default ( ) ,
}
}
/// Appends a token and checks for hallucination
pub fn append_and_check_token_hallucination (
& mut self ,
token : String ,
logprob : Vec < f64 > ,
) -> bool {
self . tokens . push ( token ) ;
self . logprobs . push ( logprob ) ;
self . process_token ( ) ;
self . hallucination
}
/// Resets internal parameters
fn reset_parameters ( & mut self ) {
self . state = None ;
self . parameter_name_done = false ;
self . hallucination = false ;
self . error_message . clear ( ) ;
self . open_bracket = false ;
self . bracket = None ;
self . check_parameter_name . clear ( ) ;
}
/// Processes the current token and updates state
fn process_token ( & mut self ) {
let content : String = self . tokens . join ( " " ) . replace ( ' ' , " " ) ;
// Handle end of tool call
if content . ends_with ( END_TOOL_CALL_TOKEN ) {
self . reset_parameters ( ) ;
}
// Function name extraction logic
if self . state . as_deref ( ) = = Some ( " function_name " ) {
2025-12-25 21:08:37 -08:00
if ! FUNC_NAME_END_TOKEN
. iter ( )
. any ( | & t | self . tokens . last ( ) . is_some_and ( | tok | tok = = t ) )
{
2025-11-22 12:55:00 -08:00
self . mask . push ( MaskToken ::FunctionName ) ;
} else {
self . state = None ;
self . get_function_name ( ) ;
}
}
// Check for function name start
2025-12-25 21:08:37 -08:00
if FUNC_NAME_START_PATTERN
. iter ( )
. any ( | & p | content . ends_with ( p ) )
{
2025-11-22 12:55:00 -08:00
self . state = Some ( " function_name " . to_string ( ) ) ;
}
// Parameter name extraction logic
if self . state . as_deref ( ) = = Some ( " parameter_name " )
2025-12-25 21:08:37 -08:00
& & ! PARAMETER_NAME_END_TOKENS
. iter ( )
. any ( | & t | content . ends_with ( t ) )
{
2025-11-22 12:55:00 -08:00
self . mask . push ( MaskToken ::ParameterName ) ;
} else if self . state . as_deref ( ) = = Some ( " parameter_name " )
2025-12-25 21:08:37 -08:00
& & PARAMETER_NAME_END_TOKENS
. iter ( )
. any ( | & t | content . ends_with ( t ) )
{
2025-11-22 12:55:00 -08:00
self . state = None ;
self . parameter_name_done = true ;
self . get_parameter_name ( ) ;
} else if self . parameter_name_done
& & ! self . open_bracket
2025-12-25 21:08:37 -08:00
& & PARAMETER_NAME_START_PATTERN
. iter ( )
. any ( | & p | content . ends_with ( p ) )
{
2025-11-22 12:55:00 -08:00
self . state = Some ( " parameter_name " . to_string ( ) ) ;
}
// First parameter value start
2025-12-25 21:08:37 -08:00
if FIRST_PARAM_NAME_START_PATTERN
. iter ( )
. any ( | & p | content . ends_with ( p ) )
{
2025-11-22 12:55:00 -08:00
self . state = Some ( " parameter_name " . to_string ( ) ) ;
}
// Parameter value extraction logic
if self . state . as_deref ( ) = = Some ( " parameter_value " )
2025-12-25 21:08:37 -08:00
& & ! PARAMETER_VALUE_END_TOKEN
. iter ( )
. any ( | & t | content . ends_with ( t ) )
{
2025-11-22 12:55:00 -08:00
// Check for brackets
if let Some ( last_token ) = self . tokens . last ( ) {
let open_brackets : Vec < char > = last_token
. trim ( )
. chars ( )
. filter ( | & c | c = = '(' | | c = = '{' | | c = = '[' )
. collect ( ) ;
if ! open_brackets . is_empty ( ) {
self . open_bracket = true ;
self . bracket = Some ( open_brackets [ 0 ] ) ;
}
if self . open_bracket {
let closing = match self . bracket {
Some ( '(' ) = > ')' ,
Some ( '{' ) = > '}' ,
Some ( '[' ) = > ']' ,
_ = > '\0' ,
} ;
if last_token . trim ( ) . contains ( closing ) {
self . open_bracket = false ;
self . bracket = None ;
}
}
// Check if token has actual value content
let has_non_punct = last_token . trim ( ) . chars ( ) . any ( | c | ! c . is_ascii_punctuation ( ) ) ;
if has_non_punct & & ! last_token . trim ( ) . is_empty ( ) {
self . mask . push ( MaskToken ::ParameterValue ) ;
// Check hallucination for required parameters without enum
if self . function_properties . contains_key ( & self . function_name ) {
if self . mask . len ( ) > 1
& & self . mask [ self . mask . len ( ) - 2 ] ! = MaskToken ::ParameterValue
& & ! self . parameter_name . is_empty ( )
{
2025-12-25 21:08:37 -08:00
let last_param =
self . parameter_name [ self . parameter_name . len ( ) - 1 ] . clone ( ) ;
if let Some ( func_props ) =
self . function_properties . get ( & self . function_name )
{
2025-11-22 12:55:00 -08:00
if is_parameter_required ( func_props , & last_param )
& & ! is_parameter_property ( func_props , & last_param , " enum " )
& & ! self . check_parameter_name . contains_key ( & last_param )
{
self . check_logprob ( ) ;
self . check_parameter_name . insert ( last_param , true ) ;
}
}
}
} else if ! self . function_name . is_empty ( ) {
self . check_logprob ( ) ;
self . error_message = format! (
" Function name {} not found in function properties " ,
self . function_name
) ;
}
} else {
self . mask . push ( MaskToken ::NotUsed ) ;
}
}
} else if self . state . as_deref ( ) = = Some ( " parameter_value " )
& & ! self . open_bracket
2025-12-25 21:08:37 -08:00
& & PARAMETER_VALUE_END_TOKEN
. iter ( )
. any ( | & t | content . ends_with ( t ) )
{
2025-11-22 12:55:00 -08:00
self . state = None ;
} else if self . parameter_name_done
2025-12-25 21:08:37 -08:00
& & PARAMETER_VALUE_START_PATTERN
. iter ( )
. any ( | & p | content . ends_with ( p ) )
{
2025-11-22 12:55:00 -08:00
self . state = Some ( " parameter_value " . to_string ( ) ) ;
}
// Maintain consistency between tokens and mask
if self . mask . len ( ) ! = self . tokens . len ( ) {
self . mask . push ( MaskToken ::NotUsed ) ;
}
}
/// Checks log probability and detects hallucination
fn check_logprob ( & mut self ) {
if let Some ( probs ) = self . logprobs . last ( ) {
let metrics = calculate_uncertainty ( probs ) ;
if let Some ( token ) = self . tokens . last ( ) {
self . token_probs_map . push ( (
token . clone ( ) ,
metrics . entropy ,
metrics . varentropy ,
metrics . probability ,
) ) ;
if check_threshold ( metrics . entropy , metrics . varentropy , & self . thresholds ) {
self . hallucination = true ;
self . error_message = format! (
" token '{}' is uncertain. Generated response: \n {} " ,
token ,
self . tokens . join ( " " )
) ;
}
}
}
}
/// Counts consecutive tokens of a specific type in the mask
fn count_consecutive_token ( & self , token_type : MaskToken ) -> usize {
if self . mask . is_empty ( ) | | self . mask . last ( ) ! = Some ( & token_type ) {
return 0 ;
}
self . mask
. iter ( )
. rev ( )
. take_while ( | & & t | t = = token_type )
. count ( )
}
/// Extracts the parameter name from recent tokens
fn get_parameter_name ( & mut self ) {
let p_len = self . count_consecutive_token ( MaskToken ::ParameterName ) ;
if p_len > 0 & & self . tokens . len ( ) > 1 {
let start_idx = self . tokens . len ( ) . saturating_sub ( p_len + 1 ) ;
let end_idx = self . tokens . len ( ) . saturating_sub ( 1 ) ;
let parameter_name : String = self . tokens [ start_idx .. end_idx ] . join ( " " ) ;
self . parameter_name . push ( parameter_name ) ;
}
}
/// Extracts the function name from recent tokens
fn get_function_name ( & mut self ) {
let f_len = self . count_consecutive_token ( MaskToken ::FunctionName ) ;
if f_len > 0 & & self . tokens . len ( ) > 1 {
let start_idx = self . tokens . len ( ) . saturating_sub ( f_len + 1 ) ;
let end_idx = self . tokens . len ( ) . saturating_sub ( 1 ) ;
self . function_name = self . tokens [ start_idx .. end_idx ] . join ( " " ) ;
}
}
}
/// Unit tests for the uncertainty/threshold helpers, the function-description lookups,
/// the `ArchFunctionHandler` type checks, and `HallucinationState` construction.
// NOTE(review): some lines in this module consist only of a timestamp
// (e.g. `2025-12-25 21:08:37 -08:00`). They are not Rust and look like
// version-control annotation residue; they must be removed for the module to compile.
// NOTE(review): string literals here appear padded with spaces (e.g. `" integer "`);
// presumably an extraction artifact as well. Confirm against the canonical file.
#[ cfg(test) ]
mod hallucination_tests {
use super ::* ;
/// For three log-probabilities, entropy and varentropy must be non-negative and the
/// probability must lie in the half-open range (0, 1].
#[ test ]
fn test_calculate_uncertainty ( ) {
let log_probs = vec! [ - 0.1 , - 2.0 , - 3.0 ] ;
let metrics = calculate_uncertainty ( & log_probs ) ;
assert! ( metrics . entropy > = 0.0 ) ;
assert! ( metrics . varentropy > = 0.0 ) ;
assert! ( metrics . probability > 0.0 & & metrics . probability < = 1.0 ) ;
}
/// An empty log-probability list must produce all-zero metrics.
#[ test ]
fn test_calculate_uncertainty_empty ( ) {
let log_probs : Vec < f64 > = vec! [ ] ;
let metrics = calculate_uncertainty ( & log_probs ) ;
assert_eq! ( metrics . entropy , 0.0 ) ;
assert_eq! ( metrics . varentropy , 0.0 ) ;
assert_eq! ( metrics . probability , 0.0 ) ;
}
/// With the default thresholds (entropy and varentropy both 0.0001), the pair
/// (0.001, 0.001) is expected to trip the check and (0.00001, 0.00001) is not.
#[ test ]
fn test_check_threshold ( ) {
let thresholds = HallucinationThresholds ::default ( ) ;
assert! ( check_threshold ( 0.001 , 0.001 , & thresholds ) ) ;
assert! ( ! check_threshold ( 0.00001 , 0.00001 , & thresholds ) ) ;
}
/// A name listed in the description's required array is reported as required;
/// a name that is not listed is not.
#[ test ]
fn test_is_parameter_required ( ) {
let func_desc = json! ( {
" required " : [ " param1 " , " param2 " ]
} ) ;
assert! ( is_parameter_required ( & func_desc , " param1 " ) ) ;
assert! ( ! is_parameter_required ( & func_desc , " param3 " ) ) ;
}
/// A parameter whose schema entry carries the queried key is reported as having
/// that property; a key absent from the entry is reported as missing.
#[ test ]
fn test_is_parameter_property ( ) {
let func_desc = json! ( {
" properties " : {
" param1 " : {
" type " : " string " ,
" enum " : [ " a " , " b " ]
}
}
} ) ;
assert! ( is_parameter_property ( & func_desc , " param1 " , " enum " ) ) ;
assert! ( ! is_parameter_property ( & func_desc , " param1 " , " default " ) ) ;
}
/// Exercises `check_value_type` for each supported type name and its alias
/// (integer/int, number/float, boolean/bool, string/str, array/list, object/dict),
/// including one mismatching value per family and an unrecognised type name.
#[ test ]
fn test_check_value_type ( ) {
// Handler built with the default config; the URL is only a constructor argument here.
let handler = ArchFunctionHandler ::new (
" test-model " . to_string ( ) ,
ArchFunctionConfig ::default ( ) ,
2025-12-25 21:08:37 -08:00
" http://localhost:8000 " . to_string ( ) ,
2025-11-22 12:55:00 -08:00
) ;
// Test integer types
assert! ( handler . check_value_type ( & json! ( 42 ) , " integer " ) ) ;
assert! ( handler . check_value_type ( & json! ( 42 ) , " int " ) ) ;
2025-12-25 21:08:37 -08:00
assert! ( ! handler . check_value_type ( & json! ( 3.15 ) , " integer " ) ) ;
2025-11-22 12:55:00 -08:00
// Test number types (accepts both int and float)
2025-12-25 21:08:37 -08:00
assert! ( handler . check_value_type ( & json! ( 3.15 ) , " number " ) ) ;
2025-11-22 12:55:00 -08:00
assert! ( handler . check_value_type ( & json! ( 42 ) , " number " ) ) ;
2025-12-25 21:08:37 -08:00
assert! ( handler . check_value_type ( & json! ( 3.15 ) , " float " ) ) ;
2025-11-22 12:55:00 -08:00
// Test boolean
assert! ( handler . check_value_type ( & json! ( true ) , " boolean " ) ) ;
assert! ( handler . check_value_type ( & json! ( false ) , " bool " ) ) ;
assert! ( ! handler . check_value_type ( & json! ( " true " ) , " boolean " ) ) ;
// Test string
assert! ( handler . check_value_type ( & json! ( " hello " ) , " string " ) ) ;
assert! ( handler . check_value_type ( & json! ( " hello " ) , " str " ) ) ;
assert! ( ! handler . check_value_type ( & json! ( 123 ) , " string " ) ) ;
// Test array
assert! ( handler . check_value_type ( & json! ( [ 1 , 2 , 3 ] ) , " array " ) ) ;
assert! ( handler . check_value_type ( & json! ( [ 1 , 2 , 3 ] ) , " list " ) ) ;
assert! ( ! handler . check_value_type ( & json! ( { } ) , " array " ) ) ;
// Test object
assert! ( handler . check_value_type ( & json! ( { " key " : " value " } ) , " object " ) ) ;
assert! ( handler . check_value_type ( & json! ( { " key " : " value " } ) , " dict " ) ) ;
assert! ( ! handler . check_value_type ( & json! ( [ ] ) , " object " ) ) ;
// Test unknown type (should return true)
assert! ( handler . check_value_type ( & json! ( 42 ) , " unknown_type " ) ) ;
}
/// Exercises `validate_or_convert_parameter`, whose result is unwrapped to a bool:
/// matching types pass, an integer is accepted for a float target, and a
/// non-numeric string is rejected for an integer target.
#[ test ]
fn test_validate_or_convert_parameter ( ) {
let handler = ArchFunctionHandler ::new (
" test-model " . to_string ( ) ,
ArchFunctionConfig ::default ( ) ,
2025-12-25 21:08:37 -08:00
" http://localhost:8000 " . to_string ( ) ,
2025-11-22 12:55:00 -08:00
) ;
// Test valid type - no conversion needed
2025-12-25 21:08:37 -08:00
assert! ( handler
. validate_or_convert_parameter ( & json! ( 42 ) , " integer " )
. unwrap ( ) ) ;
assert! ( handler
. validate_or_convert_parameter ( & json! ( " hello " ) , " string " )
. unwrap ( ) ) ;
2025-11-22 12:55:00 -08:00
// Test integer to float conversion (convert_data_type supports this)
let result = handler . validate_or_convert_parameter ( & json! ( 42 ) , " float " ) ;
assert! ( result . is_ok ( ) ) ;
assert! ( result . unwrap ( ) ) ; // Should be valid after conversion
// Test invalid type that cannot be converted
// A string cannot be converted to integer (convert_data_type doesn't support this)
let result = handler . validate_or_convert_parameter ( & json! ( " abc " ) , " integer " ) ;
// Since convert_data_type returns Ok(value.clone()) for unsupported conversions,
// the validation will fail because "abc" string is not an integer
assert! ( ! result . unwrap ( ) ) ;
// Test number accepting both int and float
2025-12-25 21:08:37 -08:00
assert! ( handler
. validate_or_convert_parameter ( & json! ( 42 ) , " number " )
. unwrap ( ) ) ;
assert! ( handler
. validate_or_convert_parameter ( & json! ( 3.15 ) , " number " )
. unwrap ( ) ) ;
2025-11-22 12:55:00 -08:00
}
/// A state built from a single tool starts with no tokens, with the hallucination
/// flag cleared, and with a `function_properties` entry keyed by the tool's
/// function name.
#[ test ]
fn test_hallucination_state_new ( ) {
let tools = vec! [ Tool {
tool_type : " function " . to_string ( ) ,
function : hermesllm ::apis ::openai ::Function {
name : " test_func " . to_string ( ) ,
description : Some ( " Test function " . to_string ( ) ) ,
parameters : json ! ( { " type " : " object " } ) ,
strict : None ,
} ,
} ] ;
let state = HallucinationState ::new ( & tools ) ;
assert_eq! ( state . tokens . len ( ) , 0 ) ;
assert! ( ! state . hallucination ) ;
assert! ( state . function_properties . contains_key ( " test_func " ) ) ;
}
}