cargo clippy (#660)

This commit is contained in:
Adil Hafeez 2025-12-25 21:08:37 -08:00 committed by GitHub
parent c75e7606f9
commit ca95ffb63d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
62 changed files with 1864 additions and 1187 deletions

View file

@ -1,20 +1,18 @@
use bytes::Bytes;
use eventsource_stream::Eventsource;
use futures::StreamExt;
use hermesllm::apis::openai::{
ChatCompletionsRequest, ChatCompletionsResponse, Choice, FinishReason, FunctionCall, Message,
MessageContent, ResponseMessage, Role, Tool, ToolCall, Usage,
};
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use hyper::body::Incoming;
use hyper::{Request, Response, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
use thiserror::Error;
use tracing::{info, error};
use futures::StreamExt;
use bytes::Bytes;
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use hyper::body::Incoming;
use hyper::{Request, Response, StatusCode};
use eventsource_stream::Eventsource;
use tracing::{error, info};
// ============================================================================
// CONSTANTS FOR HALLUCINATION DETECTION
@ -273,17 +271,14 @@ impl ArchFunctionHandler {
let mut stack: Vec<char> = Vec::new();
let mut fixed_str = String::new();
let matching_bracket: HashMap<char, char> =
[(')', '('), ('}', '{'), (']', '[')]
.iter()
.cloned()
.collect();
let opening_bracket: HashMap<char, char> = matching_bracket
let matching_bracket: HashMap<char, char> = [(')', '('), ('}', '{'), (']', '[')]
.iter()
.map(|(k, v)| (*v, *k))
.cloned()
.collect();
let opening_bracket: HashMap<char, char> =
matching_bracket.iter().map(|(k, v)| (*v, *k)).collect();
for ch in json_str.chars() {
if ch == '{' || ch == '[' || ch == '(' {
stack.push(ch);
@ -332,12 +327,18 @@ impl ArchFunctionHandler {
// Remove markdown code blocks
let mut content = content.trim().to_string();
if content.starts_with("```") && content.ends_with("```") {
content = content.trim_start_matches("```").trim_end_matches("```").to_string();
content = content
.trim_start_matches("```")
.trim_end_matches("```")
.to_string();
if content.starts_with("json") {
content = content.trim_start_matches("json").to_string();
}
// Trim again after removing code blocks to eliminate internal whitespace
content = content.trim_start_matches(r"\n").trim_end_matches(r"\n").to_string();
content = content
.trim_start_matches(r"\n")
.trim_end_matches(r"\n")
.to_string();
content = content.trim().to_string();
// Unescape the quotes: \" -> "
// The model sometimes returns escaped JSON inside markdown blocks
@ -453,12 +454,12 @@ impl ArchFunctionHandler {
/// Helper method to check if a value matches the expected type
fn check_value_type(&self, value: &Value, target_type: &str) -> bool {
match target_type {
"int" | "integer" => value.is_i64() || value.is_u64(),
"int" | "integer" => value.is_i64() || value.is_u64(),
"float" | "number" => value.is_f64() || value.is_i64() || value.is_u64(),
"bool" | "boolean" => value.is_boolean(),
"str" | "string" => value.is_string(),
"list" | "array" => value.is_array(),
"dict" | "object" => value.is_object(),
"bool" | "boolean" => value.is_boolean(),
"str" | "string" => value.is_string(),
"list" | "array" => value.is_array(),
"dict" | "object" => value.is_object(),
_ => true,
}
}
@ -505,15 +506,19 @@ impl ArchFunctionHandler {
let func_name = &tool_call.function.name;
// Parse arguments as JSON
let func_args: HashMap<String, Value> = match serde_json::from_str(&tool_call.function.arguments) {
Ok(args) => args,
Err(e) => {
verification.is_valid = false;
verification.invalid_tool_call = Some(tool_call.clone());
verification.error_message = format!("Failed to parse arguments for function '{}': {}", func_name, e);
break;
}
};
let func_args: HashMap<String, Value> =
match serde_json::from_str(&tool_call.function.arguments) {
Ok(args) => args,
Err(e) => {
verification.is_valid = false;
verification.invalid_tool_call = Some(tool_call.clone());
verification.error_message = format!(
"Failed to parse arguments for function '{}': {}",
func_name, e
);
break;
}
};
// Check if function is available
if let Some(function_params) = functions.get(func_name) {
@ -541,14 +546,23 @@ impl ArchFunctionHandler {
if let Some(properties_obj) = properties.as_object() {
for (param_name, param_value) in &func_args {
if let Some(param_schema) = properties_obj.get(param_name) {
if let Some(target_type) = param_schema.get("type").and_then(|v| v.as_str()) {
if self.config.support_data_types.contains(&target_type.to_string()) {
if let Some(target_type) =
param_schema.get("type").and_then(|v| v.as_str())
{
if self
.config
.support_data_types
.contains(&target_type.to_string())
{
// Validate data type using helper method
match self.validate_or_convert_parameter(param_value, target_type) {
match self
.validate_or_convert_parameter(param_value, target_type)
{
Ok(is_valid) => {
if !is_valid {
verification.is_valid = false;
verification.invalid_tool_call = Some(tool_call.clone());
verification.invalid_tool_call =
Some(tool_call.clone());
verification.error_message = format!(
"Parameter `{}` is expected to have the data type `{}`, got incompatible type.",
param_name, target_type
@ -558,7 +572,8 @@ impl ArchFunctionHandler {
}
Err(_) => {
verification.is_valid = false;
verification.invalid_tool_call = Some(tool_call.clone());
verification.invalid_tool_call =
Some(tool_call.clone());
verification.error_message = format!(
"Parameter `{}` is expected to have the data type `{}`, got incompatible type.",
param_name, target_type
@ -569,7 +584,10 @@ impl ArchFunctionHandler {
} else {
verification.is_valid = false;
verification.invalid_tool_call = Some(tool_call.clone());
verification.error_message = format!("Data type `{}` is not supported.", target_type);
verification.error_message = format!(
"Data type `{}` is not supported.",
target_type
);
break;
}
}
@ -598,11 +616,8 @@ impl ArchFunctionHandler {
/// Formats the system prompt with tools
pub fn format_system_prompt(&self, tools: &[Tool]) -> Result<String> {
let tools_str = self.convert_tools(tools)?;
let system_prompt = self
.config
.task_prompt
.replace("{tools}", &tools_str)
+ &self.config.format_prompt;
let system_prompt =
self.config.task_prompt.replace("{tools}", &tools_str) + &self.config.format_prompt;
Ok(system_prompt)
}
@ -665,15 +680,22 @@ impl ArchFunctionHandler {
// Strip markdown code blocks
if tool_call_msg.starts_with("```") && tool_call_msg.ends_with("```") {
tool_call_msg = tool_call_msg.trim_start_matches("```").trim_end_matches("```").trim().to_string();
tool_call_msg = tool_call_msg
.trim_start_matches("```")
.trim_end_matches("```")
.trim()
.to_string();
if tool_call_msg.starts_with("json") {
tool_call_msg = tool_call_msg.trim_start_matches("json").trim().to_string();
tool_call_msg =
tool_call_msg.trim_start_matches("json").trim().to_string();
}
}
// Extract function name
if let Ok(parsed) = serde_json::from_str::<Value>(&tool_call_msg) {
if let Some(tool_calls_arr) = parsed.get("tool_calls").and_then(|v| v.as_array()) {
if let Some(tool_calls_arr) =
parsed.get("tool_calls").and_then(|v| v.as_array())
{
if let Some(first_tool_call) = tool_calls_arr.first() {
let func_name = first_tool_call
.get("name")
@ -685,8 +707,10 @@ impl ArchFunctionHandler {
"result": content,
});
content = format!("<tool_response>\n{}\n</tool_response>",
serde_json::to_string(&tool_response)?);
content = format!(
"<tool_response>\n{}\n</tool_response>",
serde_json::to_string(&tool_response)?
);
}
}
}
@ -717,7 +741,7 @@ impl ArchFunctionHandler {
if let Some(instruction) = extra_instruction {
if let Some(last) = processed_messages.last_mut() {
if let MessageContent::Text(content) = &mut last.content {
content.push_str("\n");
content.push('\n');
content.push_str(instruction);
}
}
@ -750,13 +774,11 @@ impl ArchFunctionHandler {
for i in (conversation_idx..messages.len()).rev() {
if let MessageContent::Text(content) = &messages[i].content {
num_tokens += content.len() / 4;
if num_tokens >= max_tokens {
if messages[i].role == Role::User {
// Set message_idx to current position and break
// This matches Python's behavior where message_idx is set before break
message_idx = i;
break;
}
if num_tokens >= max_tokens && messages[i].role == Role::User {
// Set message_idx to current position and break
// This matches Python's behavior where message_idx is set before break
message_idx = i;
break;
}
}
// Only update message_idx if we haven't hit the token limit yet
@ -789,7 +811,11 @@ impl ArchFunctionHandler {
}
/// Helper to create a request with VLLM-specific parameters
fn create_request_with_extra_body(&self, messages: Vec<Message>, stream: bool) -> ChatCompletionsRequest {
fn create_request_with_extra_body(
&self,
messages: Vec<Message>,
stream: bool,
) -> ChatCompletionsRequest {
ChatCompletionsRequest {
model: self.model_name.clone(),
messages,
@ -813,24 +839,38 @@ impl ArchFunctionHandler {
}
/// Makes a streaming request and returns the SSE event stream
async fn make_streaming_request(&self, request: ChatCompletionsRequest) -> Result<std::pin::Pin<Box<dyn futures::Stream<Item = std::result::Result<Value, String>> + Send>>> {
let request_body = serde_json::to_string(&request)
.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Failed to serialize request: {}", e)))?;
async fn make_streaming_request(
&self,
request: ChatCompletionsRequest,
) -> Result<
std::pin::Pin<Box<dyn futures::Stream<Item = std::result::Result<Value, String>> + Send>>,
> {
let request_body = serde_json::to_string(&request).map_err(|e| {
FunctionCallingError::InvalidModelResponse(format!(
"Failed to serialize request: {}",
e
))
})?;
let response = self.http_client
let response = self
.http_client
.post(&self.endpoint_url)
.header("Content-Type", "application/json")
.body(request_body)
.send()
.await
.map_err(|e| FunctionCallingError::HttpError(e))?;
.map_err(FunctionCallingError::HttpError)?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
return Err(FunctionCallingError::InvalidModelResponse(
format!("HTTP error {}: {}", status, error_text)
));
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(FunctionCallingError::InvalidModelResponse(format!(
"HTTP error {}: {}",
status, error_text
)));
}
// Parse SSE stream
@ -856,38 +896,51 @@ impl ArchFunctionHandler {
}
/// Makes a non-streaming request and returns the response
async fn make_non_streaming_request(&self, request: ChatCompletionsRequest) -> Result<ChatCompletionsResponse> {
let request_body = serde_json::to_string(&request)
.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Failed to serialize request: {}", e)))?;
async fn make_non_streaming_request(
&self,
request: ChatCompletionsRequest,
) -> Result<ChatCompletionsResponse> {
let request_body = serde_json::to_string(&request).map_err(|e| {
FunctionCallingError::InvalidModelResponse(format!(
"Failed to serialize request: {}",
e
))
})?;
let response = self.http_client
let response = self
.http_client
.post(&self.endpoint_url)
.header("Content-Type", "application/json")
.body(request_body)
.send()
.await
.map_err(|e| FunctionCallingError::HttpError(e))?;
.map_err(FunctionCallingError::HttpError)?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
return Err(FunctionCallingError::InvalidModelResponse(
format!("HTTP error {}: {}", status, error_text)
));
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(FunctionCallingError::InvalidModelResponse(format!(
"HTTP error {}: {}",
status, error_text
)));
}
let response_text = response.text().await
.map_err(|e| FunctionCallingError::HttpError(e))?;
let response_text = response
.text()
.await
.map_err(FunctionCallingError::HttpError)?;
serde_json::from_str(&response_text)
.map_err(|e| FunctionCallingError::JsonParseError(e))
serde_json::from_str(&response_text).map_err(FunctionCallingError::JsonParseError)
}
pub async fn function_calling_chat(
&self,
request: ChatCompletionsRequest,
) -> Result<ChatCompletionsResponse> {
use tracing::{info, error};
use tracing::{error, info};
info!("[Arch-Function] - ChatCompletion");
@ -899,10 +952,14 @@ impl ArchFunctionHandler {
request.metadata.as_ref(),
)?;
info!("[request to arch-fc]: model: {}, messages count: {}",
self.model_name, messages.len());
info!(
"[request to arch-fc]: model: {}, messages count: {}",
self.model_name,
messages.len()
);
let use_agent_orchestrator = request.metadata
let use_agent_orchestrator = request
.metadata
.as_ref()
.and_then(|m| m.get("use_agent_orchestrator"))
.and_then(|v| v.as_bool())
@ -918,89 +975,95 @@ impl ArchFunctionHandler {
if use_agent_orchestrator {
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?;
let chunk = chunk_result.map_err(FunctionCallingError::InvalidModelResponse)?;
// Extract content from JSON response
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
if let Some(choice) = choices.first() {
if let Some(content) = choice.get("delta")
if let Some(content) = choice
.get("delta")
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str()) {
.and_then(|c| c.as_str())
{
model_response.push_str(content);
}
}
}
}
info!("[Agent Orchestrator]: response received");
} else {
if let Some(tools) = request.tools.as_ref() {
let mut hallucination_state = HallucinationState::new(tools);
let mut has_tool_calls = None;
let mut has_hallucination = false;
} else if let Some(tools) = request.tools.as_ref() {
let mut hallucination_state = HallucinationState::new(tools);
let mut has_tool_calls = None;
let mut has_hallucination = false;
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?;
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result.map_err(FunctionCallingError::InvalidModelResponse)?;
// Extract content and logprobs from JSON response
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
if let Some(choice) = choices.first() {
if let Some(content) = choice.get("delta")
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str()) {
// Extract content and logprobs from JSON response
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
if let Some(choice) = choices.first() {
if let Some(content) = choice
.get("delta")
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str())
{
// Extract logprobs
let logprobs: Vec<f64> = choice
.get("logprobs")
.and_then(|lp| lp.get("content"))
.and_then(|c| c.as_array())
.and_then(|arr| arr.first())
.and_then(|token| token.get("top_logprobs"))
.and_then(|tlp| tlp.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.get("logprob").and_then(|lp| lp.as_f64()))
.collect()
})
.unwrap_or_default();
// Extract logprobs
let logprobs: Vec<f64> = choice.get("logprobs")
.and_then(|lp| lp.get("content"))
.and_then(|c| c.as_array())
.and_then(|arr| arr.first())
.and_then(|token| token.get("top_logprobs"))
.and_then(|tlp| tlp.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.get("logprob").and_then(|lp| lp.as_f64()))
.collect()
})
.unwrap_or_default();
if hallucination_state
.append_and_check_token_hallucination(content.to_string(), logprobs)
{
has_hallucination = true;
break;
}
if hallucination_state.append_and_check_token_hallucination(content.to_string(), logprobs) {
has_hallucination = true;
break;
}
if hallucination_state.tokens.len() > 5 && has_tool_calls.is_none() {
let collected_content = hallucination_state.tokens.join("");
has_tool_calls = Some(collected_content.contains("tool_calls"));
}
if hallucination_state.tokens.len() > 5 && has_tool_calls.is_none() {
let collected_content = hallucination_state.tokens.join("");
has_tool_calls = Some(collected_content.contains("tool_calls"));
}
}
}
}
}
if has_tool_calls == Some(true) && has_hallucination {
info!("[Hallucination]: {}", hallucination_state.error_message);
if has_tool_calls == Some(true) && has_hallucination {
info!("[Hallucination]: {}", hallucination_state.error_message);
let clarify_messages = self.prefill_message(messages.clone(), &self.clarify_prefix);
let clarify_request = self.create_request_with_extra_body(clarify_messages, false);
let clarify_messages = self.prefill_message(messages.clone(), &self.clarify_prefix);
let clarify_request = self.create_request_with_extra_body(clarify_messages, false);
let retry_response = self.make_non_streaming_request(clarify_request).await?;
let retry_response = self.make_non_streaming_request(clarify_request).await?;
if let Some(choice) = retry_response.choices.first() {
if let Some(content) = &choice.message.content {
model_response = content.clone();
}
if let Some(choice) = retry_response.choices.first() {
if let Some(content) = &choice.message.content {
model_response = content.clone();
}
} else {
model_response = hallucination_state.tokens.join("");
}
} else {
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?;
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
if let Some(choice) = choices.first() {
if let Some(content) = choice.get("delta")
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str()) {
model_response.push_str(content);
}
model_response = hallucination_state.tokens.join("");
}
} else {
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result.map_err(FunctionCallingError::InvalidModelResponse)?;
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
if let Some(choice) = choices.first() {
if let Some(content) = choice
.get("delta")
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str())
{
model_response.push_str(content);
}
}
}
@ -1009,10 +1072,17 @@ impl ArchFunctionHandler {
let response_dict = self.parse_model_response(&model_response);
info!("[arch-fc]: raw model response: {}", response_dict.raw_response);
info!(
"[arch-fc]: raw model response: {}",
response_dict.raw_response
);
// General model response (no intent matched - should route to default target)
let model_message = if response_dict.response.as_ref().map_or(false, |s| !s.is_empty()) {
let model_message = if response_dict
.response
.as_ref()
.is_some_and(|s| !s.is_empty())
{
// When arch-fc returns a "response" field, it means no intent was matched
// Return empty content and empty tool_calls so prompt_gateway routes to default target
ResponseMessage {
@ -1053,8 +1123,11 @@ impl ArchFunctionHandler {
let verification = self.verify_tool_calls(tools, &response_dict.tool_calls);
if verification.is_valid {
info!("[Tool calls]: {:?}",
response_dict.tool_calls.iter()
info!(
"[Tool calls]: {:?}",
response_dict
.tool_calls
.iter()
.map(|tc| &tc.function)
.collect::<Vec<_>>()
);
@ -1092,8 +1165,11 @@ impl ArchFunctionHandler {
}
}
} else {
info!("[Tool calls]: {:?}",
response_dict.tool_calls.iter()
info!(
"[Tool calls]: {:?}",
response_dict
.tool_calls
.iter()
.map(|tc| &tc.function)
.collect::<Vec<_>>()
);
@ -1108,7 +1184,10 @@ impl ArchFunctionHandler {
}
}
} else {
error!("Invalid tool calls in response: {}", response_dict.error_message);
error!(
"Invalid tool calls in response: {}",
response_dict.error_message
);
ResponseMessage {
role: Role::Assistant,
content: Some(String::new()),
@ -1243,7 +1322,6 @@ pub async fn function_calling_chat_handler(
req: Request<Incoming>,
llm_provider_url: String,
) -> std::result::Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
use hermesllm::apis::openai::ChatCompletionsRequest;
let whole_body = req.collect().await?.to_bytes();
@ -1255,10 +1333,13 @@ pub async fn function_calling_chat_handler(
let mut response = Response::new(full(
serde_json::json!({
"error": format!("Invalid request body: {}", e)
}).to_string()
})
.to_string(),
));
*response.status_mut() = StatusCode::BAD_REQUEST;
response.headers_mut().insert("Content-Type", "application/json".parse().unwrap());
response
.headers_mut()
.insert("Content-Type", "application/json".parse().unwrap());
return Ok(response);
}
};
@ -1271,24 +1352,31 @@ pub async fn function_calling_chat_handler(
// Parse as ChatCompletionsRequest
let chat_request: ChatCompletionsRequest = match serde_json::from_value(body_json) {
Ok(req) => {
info!("[request body]: {}", serde_json::to_string(&req).unwrap_or_default());
info!(
"[request body]: {}",
serde_json::to_string(&req).unwrap_or_default()
);
req
},
}
Err(e) => {
error!("Failed to parse request body: {}", e);
let mut response = Response::new(full(
serde_json::json!({
"error": format!("Invalid request body: {}", e)
}).to_string()
})
.to_string(),
));
*response.status_mut() = StatusCode::BAD_REQUEST;
response.headers_mut().insert("Content-Type", "application/json".parse().unwrap());
response
.headers_mut()
.insert("Content-Type", "application/json".parse().unwrap());
return Ok(response);
}
};
// Determine which handler to use based on metadata
let use_agent_orchestrator = chat_request.metadata
let use_agent_orchestrator = chat_request
.metadata
.as_ref()
.and_then(|m| m.get("use_agent_orchestrator"))
.and_then(|v| v.as_bool())
@ -1309,7 +1397,10 @@ pub async fn function_calling_chat_handler(
ARCH_FUNCTION_MODEL_NAME.to_string(),
llm_provider_url.clone(),
);
handler.function_handler.function_calling_chat(chat_request).await
handler
.function_handler
.function_calling_chat(chat_request)
.await
} else {
let handler = ArchFunctionHandler::new(
ARCH_FUNCTION_MODEL_NAME.to_string(),
@ -1328,7 +1419,9 @@ pub async fn function_calling_chat_handler(
let mut response = Response::new(full(response_json));
*response.status_mut() = StatusCode::OK;
response.headers_mut().insert("Content-Type", "application/json".parse().unwrap());
response
.headers_mut()
.insert("Content-Type", "application/json".parse().unwrap());
Ok(response)
}
@ -1341,13 +1434,14 @@ pub async fn function_calling_chat_handler(
let mut response = Response::new(full(error_response.to_string()));
*response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
response.headers_mut().insert("Content-Type", "application/json".parse().unwrap());
response
.headers_mut()
.insert("Content-Type", "application/json".parse().unwrap());
Ok(response)
}
}
}
// ============================================================================
// TESTS
// ============================================================================
@ -1370,10 +1464,13 @@ mod tests {
assert!(config.task_prompt.contains("</tools>\\n\\n"));
// Format prompt should contain literal escaped newlines and proper JSON examples
assert!(config.format_prompt.contains("\\n\\nBased on your analysis"));
assert!(config.format_prompt.contains(r#"{\"response\": \"Your response text here\"}"#));
assert!(config
.format_prompt
.contains("\\n\\nBased on your analysis"));
assert!(config
.format_prompt
.contains(r#"{\"response\": \"Your response text here\"}"#));
assert!(config.format_prompt.contains(r#"{\"tool_calls\": [{"#));
}
#[test]
@ -1384,7 +1481,11 @@ mod tests {
#[test]
fn test_fix_json_string_valid() {
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string());
let handler = ArchFunctionHandler::new(
"test-model".to_string(),
ArchFunctionConfig::default(),
"http://localhost:8000".to_string(),
);
let json_str = r#"{"name": "test", "value": 123}"#;
let result = handler.fix_json_string(json_str);
assert!(result.is_ok());
@ -1392,7 +1493,11 @@ mod tests {
#[test]
fn test_fix_json_string_missing_bracket() {
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string());
let handler = ArchFunctionHandler::new(
"test-model".to_string(),
ArchFunctionConfig::default(),
"http://localhost:8000".to_string(),
);
let json_str = r#"{"name": "test", "value": 123"#;
let result = handler.fix_json_string(json_str);
assert!(result.is_ok());
@ -1402,8 +1507,13 @@ mod tests {
#[test]
fn test_parse_model_response_with_tool_calls() {
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string());
let content = r#"{"tool_calls": [{"name": "get_weather", "arguments": {"location": "NYC"}}]}"#;
let handler = ArchFunctionHandler::new(
"test-model".to_string(),
ArchFunctionConfig::default(),
"http://localhost:8000".to_string(),
);
let content =
r#"{"tool_calls": [{"name": "get_weather", "arguments": {"location": "NYC"}}]}"#;
let result = handler.parse_model_response(content);
assert!(result.is_valid);
@ -1413,8 +1523,13 @@ mod tests {
#[test]
fn test_parse_model_response_with_clarification() {
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string());
let content = r#"{"required_functions": ["get_weather"], "clarification": "What location?"}"#;
let handler = ArchFunctionHandler::new(
"test-model".to_string(),
ArchFunctionConfig::default(),
"http://localhost:8000".to_string(),
);
let content =
r#"{"required_functions": ["get_weather"], "clarification": "What location?"}"#;
let result = handler.parse_model_response(content);
assert!(result.is_valid);
@ -1424,7 +1539,11 @@ mod tests {
#[test]
fn test_convert_data_type_int_to_float() {
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string());
let handler = ArchFunctionHandler::new(
"test-model".to_string(),
ArchFunctionConfig::default(),
"http://localhost:8000".to_string(),
);
let value = json!(42);
let result = handler.convert_data_type(&value, "float");
assert!(result.is_ok());
@ -1504,13 +1623,12 @@ pub fn check_threshold(
}
/// Checks if a parameter is required in the function description
pub fn is_parameter_required(
function_description: &Value,
parameter_name: &str,
) -> bool {
pub fn is_parameter_required(function_description: &Value, parameter_name: &str) -> bool {
if let Some(required) = function_description.get("required") {
if let Some(required_arr) = required.as_array() {
return required_arr.iter().any(|v| v.as_str() == Some(parameter_name));
return required_arr
.iter()
.any(|v| v.as_str() == Some(parameter_name));
}
}
false
@ -1559,12 +1677,7 @@ impl HallucinationState {
pub fn new(functions: &[Tool]) -> Self {
let function_properties: HashMap<String, Value> = functions
.iter()
.map(|tool| {
(
tool.function.name.clone(),
tool.function.parameters.clone(),
)
})
.map(|tool| (tool.function.name.clone(), tool.function.parameters.clone()))
.collect();
Self {
@ -1620,7 +1733,10 @@ impl HallucinationState {
// Function name extraction logic
if self.state.as_deref() == Some("function_name") {
if !FUNC_NAME_END_TOKEN.iter().any(|&t| self.tokens.last().map_or(false, |tok| tok == t)) {
if !FUNC_NAME_END_TOKEN
.iter()
.any(|&t| self.tokens.last().is_some_and(|tok| tok == t))
{
self.mask.push(MaskToken::FunctionName);
} else {
self.state = None;
@ -1629,34 +1745,51 @@ impl HallucinationState {
}
// Check for function name start
if FUNC_NAME_START_PATTERN.iter().any(|&p| content.ends_with(p)) {
if FUNC_NAME_START_PATTERN
.iter()
.any(|&p| content.ends_with(p))
{
self.state = Some("function_name".to_string());
}
// Parameter name extraction logic
if self.state.as_deref() == Some("parameter_name")
&& !PARAMETER_NAME_END_TOKENS.iter().any(|&t| content.ends_with(t)) {
&& !PARAMETER_NAME_END_TOKENS
.iter()
.any(|&t| content.ends_with(t))
{
self.mask.push(MaskToken::ParameterName);
} else if self.state.as_deref() == Some("parameter_name")
&& PARAMETER_NAME_END_TOKENS.iter().any(|&t| content.ends_with(t)) {
&& PARAMETER_NAME_END_TOKENS
.iter()
.any(|&t| content.ends_with(t))
{
self.state = None;
self.parameter_name_done = true;
self.get_parameter_name();
} else if self.parameter_name_done
&& !self.open_bracket
&& PARAMETER_NAME_START_PATTERN.iter().any(|&p| content.ends_with(p)) {
&& PARAMETER_NAME_START_PATTERN
.iter()
.any(|&p| content.ends_with(p))
{
self.state = Some("parameter_name".to_string());
}
// First parameter value start
if FIRST_PARAM_NAME_START_PATTERN.iter().any(|&p| content.ends_with(p)) {
if FIRST_PARAM_NAME_START_PATTERN
.iter()
.any(|&p| content.ends_with(p))
{
self.state = Some("parameter_name".to_string());
}
// Parameter value extraction logic
if self.state.as_deref() == Some("parameter_value")
&& !PARAMETER_VALUE_END_TOKEN.iter().any(|&t| content.ends_with(t)) {
&& !PARAMETER_VALUE_END_TOKEN
.iter()
.any(|&t| content.ends_with(t))
{
// Check for brackets
if let Some(last_token) = self.tokens.last() {
let open_brackets: Vec<char> = last_token
@ -1694,8 +1827,11 @@ impl HallucinationState {
&& self.mask[self.mask.len() - 2] != MaskToken::ParameterValue
&& !self.parameter_name.is_empty()
{
let last_param = self.parameter_name[self.parameter_name.len() - 1].clone();
if let Some(func_props) = self.function_properties.get(&self.function_name) {
let last_param =
self.parameter_name[self.parameter_name.len() - 1].clone();
if let Some(func_props) =
self.function_properties.get(&self.function_name)
{
if is_parameter_required(func_props, &last_param)
&& !is_parameter_property(func_props, &last_param, "enum")
&& !self.check_parameter_name.contains_key(&last_param)
@ -1718,10 +1854,16 @@ impl HallucinationState {
}
} else if self.state.as_deref() == Some("parameter_value")
&& !self.open_bracket
&& PARAMETER_VALUE_END_TOKEN.iter().any(|&t| content.ends_with(t)) {
&& PARAMETER_VALUE_END_TOKEN
.iter()
.any(|&t| content.ends_with(t))
{
self.state = None;
} else if self.parameter_name_done
&& PARAMETER_VALUE_START_PATTERN.iter().any(|&p| content.ends_with(p)) {
&& PARAMETER_VALUE_START_PATTERN
.iter()
.any(|&p| content.ends_with(p))
{
self.state = Some("parameter_value".to_string());
}
@ -1848,18 +1990,18 @@ mod hallucination_tests {
let handler = ArchFunctionHandler::new(
"test-model".to_string(),
ArchFunctionConfig::default(),
"http://localhost:8000".to_string()
"http://localhost:8000".to_string(),
);
// Test integer types
assert!(handler.check_value_type(&json!(42), "integer"));
assert!(handler.check_value_type(&json!(42), "int"));
assert!(!handler.check_value_type(&json!(3.14), "integer"));
assert!(!handler.check_value_type(&json!(3.15), "integer"));
// Test number types (accepts both int and float)
assert!(handler.check_value_type(&json!(3.14), "number"));
assert!(handler.check_value_type(&json!(3.15), "number"));
assert!(handler.check_value_type(&json!(42), "number"));
assert!(handler.check_value_type(&json!(3.14), "float"));
assert!(handler.check_value_type(&json!(3.15), "float"));
// Test boolean
assert!(handler.check_value_type(&json!(true), "boolean"));
@ -1890,12 +2032,16 @@ mod hallucination_tests {
let handler = ArchFunctionHandler::new(
"test-model".to_string(),
ArchFunctionConfig::default(),
"http://localhost:8000".to_string()
"http://localhost:8000".to_string(),
);
// Test valid type - no conversion needed
assert!(handler.validate_or_convert_parameter(&json!(42), "integer").unwrap());
assert!(handler.validate_or_convert_parameter(&json!("hello"), "string").unwrap());
assert!(handler
.validate_or_convert_parameter(&json!(42), "integer")
.unwrap());
assert!(handler
.validate_or_convert_parameter(&json!("hello"), "string")
.unwrap());
// Test integer to float conversion (convert_data_type supports this)
let result = handler.validate_or_convert_parameter(&json!(42), "float");
@ -1910,8 +2056,12 @@ mod hallucination_tests {
assert!(!result.unwrap());
// Test number accepting both int and float
assert!(handler.validate_or_convert_parameter(&json!(42), "number").unwrap());
assert!(handler.validate_or_convert_parameter(&json!(3.14), "number").unwrap());
assert!(handler
.validate_or_convert_parameter(&json!(42), "number")
.unwrap());
assert!(handler
.validate_or_convert_parameter(&json!(3.15), "number")
.unwrap());
}
#[test]