mirror of
https://github.com/katanemo/plano.git
synced 2026-05-18 13:45:15 +02:00
cargo clippy (#660)
This commit is contained in:
parent
c75e7606f9
commit
ca95ffb63d
62 changed files with 1864 additions and 1187 deletions
|
|
@ -1,20 +1,18 @@
|
|||
use bytes::Bytes;
|
||||
use eventsource_stream::Eventsource;
|
||||
use futures::StreamExt;
|
||||
use hermesllm::apis::openai::{
|
||||
ChatCompletionsRequest, ChatCompletionsResponse, Choice, FinishReason, FunctionCall, Message,
|
||||
MessageContent, ResponseMessage, Role, Tool, ToolCall, Usage,
|
||||
};
|
||||
use http_body_util::{combinators::BoxBody, BodyExt, Full};
|
||||
use hyper::body::Incoming;
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use std::collections::HashMap;
|
||||
use thiserror::Error;
|
||||
use tracing::{info, error};
|
||||
use futures::StreamExt;
|
||||
use bytes::Bytes;
|
||||
use http_body_util::{combinators::BoxBody, BodyExt, Full};
|
||||
use hyper::body::Incoming;
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use eventsource_stream::Eventsource;
|
||||
|
||||
|
||||
use tracing::{error, info};
|
||||
|
||||
// ============================================================================
|
||||
// CONSTANTS FOR HALLUCINATION DETECTION
|
||||
|
|
@ -273,17 +271,14 @@ impl ArchFunctionHandler {
|
|||
let mut stack: Vec<char> = Vec::new();
|
||||
let mut fixed_str = String::new();
|
||||
|
||||
let matching_bracket: HashMap<char, char> =
|
||||
[(')', '('), ('}', '{'), (']', '[')]
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
let opening_bracket: HashMap<char, char> = matching_bracket
|
||||
let matching_bracket: HashMap<char, char> = [(')', '('), ('}', '{'), (']', '[')]
|
||||
.iter()
|
||||
.map(|(k, v)| (*v, *k))
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
let opening_bracket: HashMap<char, char> =
|
||||
matching_bracket.iter().map(|(k, v)| (*v, *k)).collect();
|
||||
|
||||
for ch in json_str.chars() {
|
||||
if ch == '{' || ch == '[' || ch == '(' {
|
||||
stack.push(ch);
|
||||
|
|
@ -332,12 +327,18 @@ impl ArchFunctionHandler {
|
|||
// Remove markdown code blocks
|
||||
let mut content = content.trim().to_string();
|
||||
if content.starts_with("```") && content.ends_with("```") {
|
||||
content = content.trim_start_matches("```").trim_end_matches("```").to_string();
|
||||
content = content
|
||||
.trim_start_matches("```")
|
||||
.trim_end_matches("```")
|
||||
.to_string();
|
||||
if content.starts_with("json") {
|
||||
content = content.trim_start_matches("json").to_string();
|
||||
}
|
||||
// Trim again after removing code blocks to eliminate internal whitespace
|
||||
content = content.trim_start_matches(r"\n").trim_end_matches(r"\n").to_string();
|
||||
content = content
|
||||
.trim_start_matches(r"\n")
|
||||
.trim_end_matches(r"\n")
|
||||
.to_string();
|
||||
content = content.trim().to_string();
|
||||
// Unescape the quotes: \" -> "
|
||||
// The model sometimes returns escaped JSON inside markdown blocks
|
||||
|
|
@ -453,12 +454,12 @@ impl ArchFunctionHandler {
|
|||
/// Helper method to check if a value matches the expected type
|
||||
fn check_value_type(&self, value: &Value, target_type: &str) -> bool {
|
||||
match target_type {
|
||||
"int" | "integer" => value.is_i64() || value.is_u64(),
|
||||
"int" | "integer" => value.is_i64() || value.is_u64(),
|
||||
"float" | "number" => value.is_f64() || value.is_i64() || value.is_u64(),
|
||||
"bool" | "boolean" => value.is_boolean(),
|
||||
"str" | "string" => value.is_string(),
|
||||
"list" | "array" => value.is_array(),
|
||||
"dict" | "object" => value.is_object(),
|
||||
"bool" | "boolean" => value.is_boolean(),
|
||||
"str" | "string" => value.is_string(),
|
||||
"list" | "array" => value.is_array(),
|
||||
"dict" | "object" => value.is_object(),
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
|
|
@ -505,15 +506,19 @@ impl ArchFunctionHandler {
|
|||
let func_name = &tool_call.function.name;
|
||||
|
||||
// Parse arguments as JSON
|
||||
let func_args: HashMap<String, Value> = match serde_json::from_str(&tool_call.function.arguments) {
|
||||
Ok(args) => args,
|
||||
Err(e) => {
|
||||
verification.is_valid = false;
|
||||
verification.invalid_tool_call = Some(tool_call.clone());
|
||||
verification.error_message = format!("Failed to parse arguments for function '{}': {}", func_name, e);
|
||||
break;
|
||||
}
|
||||
};
|
||||
let func_args: HashMap<String, Value> =
|
||||
match serde_json::from_str(&tool_call.function.arguments) {
|
||||
Ok(args) => args,
|
||||
Err(e) => {
|
||||
verification.is_valid = false;
|
||||
verification.invalid_tool_call = Some(tool_call.clone());
|
||||
verification.error_message = format!(
|
||||
"Failed to parse arguments for function '{}': {}",
|
||||
func_name, e
|
||||
);
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
// Check if function is available
|
||||
if let Some(function_params) = functions.get(func_name) {
|
||||
|
|
@ -541,14 +546,23 @@ impl ArchFunctionHandler {
|
|||
if let Some(properties_obj) = properties.as_object() {
|
||||
for (param_name, param_value) in &func_args {
|
||||
if let Some(param_schema) = properties_obj.get(param_name) {
|
||||
if let Some(target_type) = param_schema.get("type").and_then(|v| v.as_str()) {
|
||||
if self.config.support_data_types.contains(&target_type.to_string()) {
|
||||
if let Some(target_type) =
|
||||
param_schema.get("type").and_then(|v| v.as_str())
|
||||
{
|
||||
if self
|
||||
.config
|
||||
.support_data_types
|
||||
.contains(&target_type.to_string())
|
||||
{
|
||||
// Validate data type using helper method
|
||||
match self.validate_or_convert_parameter(param_value, target_type) {
|
||||
match self
|
||||
.validate_or_convert_parameter(param_value, target_type)
|
||||
{
|
||||
Ok(is_valid) => {
|
||||
if !is_valid {
|
||||
verification.is_valid = false;
|
||||
verification.invalid_tool_call = Some(tool_call.clone());
|
||||
verification.invalid_tool_call =
|
||||
Some(tool_call.clone());
|
||||
verification.error_message = format!(
|
||||
"Parameter `{}` is expected to have the data type `{}`, got incompatible type.",
|
||||
param_name, target_type
|
||||
|
|
@ -558,7 +572,8 @@ impl ArchFunctionHandler {
|
|||
}
|
||||
Err(_) => {
|
||||
verification.is_valid = false;
|
||||
verification.invalid_tool_call = Some(tool_call.clone());
|
||||
verification.invalid_tool_call =
|
||||
Some(tool_call.clone());
|
||||
verification.error_message = format!(
|
||||
"Parameter `{}` is expected to have the data type `{}`, got incompatible type.",
|
||||
param_name, target_type
|
||||
|
|
@ -569,7 +584,10 @@ impl ArchFunctionHandler {
|
|||
} else {
|
||||
verification.is_valid = false;
|
||||
verification.invalid_tool_call = Some(tool_call.clone());
|
||||
verification.error_message = format!("Data type `{}` is not supported.", target_type);
|
||||
verification.error_message = format!(
|
||||
"Data type `{}` is not supported.",
|
||||
target_type
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
@ -598,11 +616,8 @@ impl ArchFunctionHandler {
|
|||
/// Formats the system prompt with tools
|
||||
pub fn format_system_prompt(&self, tools: &[Tool]) -> Result<String> {
|
||||
let tools_str = self.convert_tools(tools)?;
|
||||
let system_prompt = self
|
||||
.config
|
||||
.task_prompt
|
||||
.replace("{tools}", &tools_str)
|
||||
+ &self.config.format_prompt;
|
||||
let system_prompt =
|
||||
self.config.task_prompt.replace("{tools}", &tools_str) + &self.config.format_prompt;
|
||||
|
||||
Ok(system_prompt)
|
||||
}
|
||||
|
|
@ -665,15 +680,22 @@ impl ArchFunctionHandler {
|
|||
|
||||
// Strip markdown code blocks
|
||||
if tool_call_msg.starts_with("```") && tool_call_msg.ends_with("```") {
|
||||
tool_call_msg = tool_call_msg.trim_start_matches("```").trim_end_matches("```").trim().to_string();
|
||||
tool_call_msg = tool_call_msg
|
||||
.trim_start_matches("```")
|
||||
.trim_end_matches("```")
|
||||
.trim()
|
||||
.to_string();
|
||||
if tool_call_msg.starts_with("json") {
|
||||
tool_call_msg = tool_call_msg.trim_start_matches("json").trim().to_string();
|
||||
tool_call_msg =
|
||||
tool_call_msg.trim_start_matches("json").trim().to_string();
|
||||
}
|
||||
}
|
||||
|
||||
// Extract function name
|
||||
if let Ok(parsed) = serde_json::from_str::<Value>(&tool_call_msg) {
|
||||
if let Some(tool_calls_arr) = parsed.get("tool_calls").and_then(|v| v.as_array()) {
|
||||
if let Some(tool_calls_arr) =
|
||||
parsed.get("tool_calls").and_then(|v| v.as_array())
|
||||
{
|
||||
if let Some(first_tool_call) = tool_calls_arr.first() {
|
||||
let func_name = first_tool_call
|
||||
.get("name")
|
||||
|
|
@ -685,8 +707,10 @@ impl ArchFunctionHandler {
|
|||
"result": content,
|
||||
});
|
||||
|
||||
content = format!("<tool_response>\n{}\n</tool_response>",
|
||||
serde_json::to_string(&tool_response)?);
|
||||
content = format!(
|
||||
"<tool_response>\n{}\n</tool_response>",
|
||||
serde_json::to_string(&tool_response)?
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -717,7 +741,7 @@ impl ArchFunctionHandler {
|
|||
if let Some(instruction) = extra_instruction {
|
||||
if let Some(last) = processed_messages.last_mut() {
|
||||
if let MessageContent::Text(content) = &mut last.content {
|
||||
content.push_str("\n");
|
||||
content.push('\n');
|
||||
content.push_str(instruction);
|
||||
}
|
||||
}
|
||||
|
|
@ -750,13 +774,11 @@ impl ArchFunctionHandler {
|
|||
for i in (conversation_idx..messages.len()).rev() {
|
||||
if let MessageContent::Text(content) = &messages[i].content {
|
||||
num_tokens += content.len() / 4;
|
||||
if num_tokens >= max_tokens {
|
||||
if messages[i].role == Role::User {
|
||||
// Set message_idx to current position and break
|
||||
// This matches Python's behavior where message_idx is set before break
|
||||
message_idx = i;
|
||||
break;
|
||||
}
|
||||
if num_tokens >= max_tokens && messages[i].role == Role::User {
|
||||
// Set message_idx to current position and break
|
||||
// This matches Python's behavior where message_idx is set before break
|
||||
message_idx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Only update message_idx if we haven't hit the token limit yet
|
||||
|
|
@ -789,7 +811,11 @@ impl ArchFunctionHandler {
|
|||
}
|
||||
|
||||
/// Helper to create a request with VLLM-specific parameters
|
||||
fn create_request_with_extra_body(&self, messages: Vec<Message>, stream: bool) -> ChatCompletionsRequest {
|
||||
fn create_request_with_extra_body(
|
||||
&self,
|
||||
messages: Vec<Message>,
|
||||
stream: bool,
|
||||
) -> ChatCompletionsRequest {
|
||||
ChatCompletionsRequest {
|
||||
model: self.model_name.clone(),
|
||||
messages,
|
||||
|
|
@ -813,24 +839,38 @@ impl ArchFunctionHandler {
|
|||
}
|
||||
|
||||
/// Makes a streaming request and returns the SSE event stream
|
||||
async fn make_streaming_request(&self, request: ChatCompletionsRequest) -> Result<std::pin::Pin<Box<dyn futures::Stream<Item = std::result::Result<Value, String>> + Send>>> {
|
||||
let request_body = serde_json::to_string(&request)
|
||||
.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Failed to serialize request: {}", e)))?;
|
||||
async fn make_streaming_request(
|
||||
&self,
|
||||
request: ChatCompletionsRequest,
|
||||
) -> Result<
|
||||
std::pin::Pin<Box<dyn futures::Stream<Item = std::result::Result<Value, String>> + Send>>,
|
||||
> {
|
||||
let request_body = serde_json::to_string(&request).map_err(|e| {
|
||||
FunctionCallingError::InvalidModelResponse(format!(
|
||||
"Failed to serialize request: {}",
|
||||
e
|
||||
))
|
||||
})?;
|
||||
|
||||
let response = self.http_client
|
||||
let response = self
|
||||
.http_client
|
||||
.post(&self.endpoint_url)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(request_body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| FunctionCallingError::HttpError(e))?;
|
||||
.map_err(FunctionCallingError::HttpError)?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(FunctionCallingError::InvalidModelResponse(
|
||||
format!("HTTP error {}: {}", status, error_text)
|
||||
));
|
||||
let error_text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(FunctionCallingError::InvalidModelResponse(format!(
|
||||
"HTTP error {}: {}",
|
||||
status, error_text
|
||||
)));
|
||||
}
|
||||
|
||||
// Parse SSE stream
|
||||
|
|
@ -856,38 +896,51 @@ impl ArchFunctionHandler {
|
|||
}
|
||||
|
||||
/// Makes a non-streaming request and returns the response
|
||||
async fn make_non_streaming_request(&self, request: ChatCompletionsRequest) -> Result<ChatCompletionsResponse> {
|
||||
let request_body = serde_json::to_string(&request)
|
||||
.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Failed to serialize request: {}", e)))?;
|
||||
async fn make_non_streaming_request(
|
||||
&self,
|
||||
request: ChatCompletionsRequest,
|
||||
) -> Result<ChatCompletionsResponse> {
|
||||
let request_body = serde_json::to_string(&request).map_err(|e| {
|
||||
FunctionCallingError::InvalidModelResponse(format!(
|
||||
"Failed to serialize request: {}",
|
||||
e
|
||||
))
|
||||
})?;
|
||||
|
||||
let response = self.http_client
|
||||
let response = self
|
||||
.http_client
|
||||
.post(&self.endpoint_url)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(request_body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| FunctionCallingError::HttpError(e))?;
|
||||
.map_err(FunctionCallingError::HttpError)?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(FunctionCallingError::InvalidModelResponse(
|
||||
format!("HTTP error {}: {}", status, error_text)
|
||||
));
|
||||
let error_text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(FunctionCallingError::InvalidModelResponse(format!(
|
||||
"HTTP error {}: {}",
|
||||
status, error_text
|
||||
)));
|
||||
}
|
||||
|
||||
let response_text = response.text().await
|
||||
.map_err(|e| FunctionCallingError::HttpError(e))?;
|
||||
let response_text = response
|
||||
.text()
|
||||
.await
|
||||
.map_err(FunctionCallingError::HttpError)?;
|
||||
|
||||
serde_json::from_str(&response_text)
|
||||
.map_err(|e| FunctionCallingError::JsonParseError(e))
|
||||
serde_json::from_str(&response_text).map_err(FunctionCallingError::JsonParseError)
|
||||
}
|
||||
|
||||
pub async fn function_calling_chat(
|
||||
&self,
|
||||
request: ChatCompletionsRequest,
|
||||
) -> Result<ChatCompletionsResponse> {
|
||||
use tracing::{info, error};
|
||||
use tracing::{error, info};
|
||||
|
||||
info!("[Arch-Function] - ChatCompletion");
|
||||
|
||||
|
|
@ -899,10 +952,14 @@ impl ArchFunctionHandler {
|
|||
request.metadata.as_ref(),
|
||||
)?;
|
||||
|
||||
info!("[request to arch-fc]: model: {}, messages count: {}",
|
||||
self.model_name, messages.len());
|
||||
info!(
|
||||
"[request to arch-fc]: model: {}, messages count: {}",
|
||||
self.model_name,
|
||||
messages.len()
|
||||
);
|
||||
|
||||
let use_agent_orchestrator = request.metadata
|
||||
let use_agent_orchestrator = request
|
||||
.metadata
|
||||
.as_ref()
|
||||
.and_then(|m| m.get("use_agent_orchestrator"))
|
||||
.and_then(|v| v.as_bool())
|
||||
|
|
@ -918,89 +975,95 @@ impl ArchFunctionHandler {
|
|||
|
||||
if use_agent_orchestrator {
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?;
|
||||
let chunk = chunk_result.map_err(FunctionCallingError::InvalidModelResponse)?;
|
||||
// Extract content from JSON response
|
||||
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
|
||||
if let Some(choice) = choices.first() {
|
||||
if let Some(content) = choice.get("delta")
|
||||
if let Some(content) = choice
|
||||
.get("delta")
|
||||
.and_then(|d| d.get("content"))
|
||||
.and_then(|c| c.as_str()) {
|
||||
.and_then(|c| c.as_str())
|
||||
{
|
||||
model_response.push_str(content);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
info!("[Agent Orchestrator]: response received");
|
||||
} else {
|
||||
if let Some(tools) = request.tools.as_ref() {
|
||||
let mut hallucination_state = HallucinationState::new(tools);
|
||||
let mut has_tool_calls = None;
|
||||
let mut has_hallucination = false;
|
||||
} else if let Some(tools) = request.tools.as_ref() {
|
||||
let mut hallucination_state = HallucinationState::new(tools);
|
||||
let mut has_tool_calls = None;
|
||||
let mut has_hallucination = false;
|
||||
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?;
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
let chunk = chunk_result.map_err(FunctionCallingError::InvalidModelResponse)?;
|
||||
|
||||
// Extract content and logprobs from JSON response
|
||||
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
|
||||
if let Some(choice) = choices.first() {
|
||||
if let Some(content) = choice.get("delta")
|
||||
.and_then(|d| d.get("content"))
|
||||
.and_then(|c| c.as_str()) {
|
||||
// Extract content and logprobs from JSON response
|
||||
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
|
||||
if let Some(choice) = choices.first() {
|
||||
if let Some(content) = choice
|
||||
.get("delta")
|
||||
.and_then(|d| d.get("content"))
|
||||
.and_then(|c| c.as_str())
|
||||
{
|
||||
// Extract logprobs
|
||||
let logprobs: Vec<f64> = choice
|
||||
.get("logprobs")
|
||||
.and_then(|lp| lp.get("content"))
|
||||
.and_then(|c| c.as_array())
|
||||
.and_then(|arr| arr.first())
|
||||
.and_then(|token| token.get("top_logprobs"))
|
||||
.and_then(|tlp| tlp.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.get("logprob").and_then(|lp| lp.as_f64()))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
// Extract logprobs
|
||||
let logprobs: Vec<f64> = choice.get("logprobs")
|
||||
.and_then(|lp| lp.get("content"))
|
||||
.and_then(|c| c.as_array())
|
||||
.and_then(|arr| arr.first())
|
||||
.and_then(|token| token.get("top_logprobs"))
|
||||
.and_then(|tlp| tlp.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.get("logprob").and_then(|lp| lp.as_f64()))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
if hallucination_state
|
||||
.append_and_check_token_hallucination(content.to_string(), logprobs)
|
||||
{
|
||||
has_hallucination = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if hallucination_state.append_and_check_token_hallucination(content.to_string(), logprobs) {
|
||||
has_hallucination = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if hallucination_state.tokens.len() > 5 && has_tool_calls.is_none() {
|
||||
let collected_content = hallucination_state.tokens.join("");
|
||||
has_tool_calls = Some(collected_content.contains("tool_calls"));
|
||||
}
|
||||
if hallucination_state.tokens.len() > 5 && has_tool_calls.is_none() {
|
||||
let collected_content = hallucination_state.tokens.join("");
|
||||
has_tool_calls = Some(collected_content.contains("tool_calls"));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if has_tool_calls == Some(true) && has_hallucination {
|
||||
info!("[Hallucination]: {}", hallucination_state.error_message);
|
||||
if has_tool_calls == Some(true) && has_hallucination {
|
||||
info!("[Hallucination]: {}", hallucination_state.error_message);
|
||||
|
||||
let clarify_messages = self.prefill_message(messages.clone(), &self.clarify_prefix);
|
||||
let clarify_request = self.create_request_with_extra_body(clarify_messages, false);
|
||||
let clarify_messages = self.prefill_message(messages.clone(), &self.clarify_prefix);
|
||||
let clarify_request = self.create_request_with_extra_body(clarify_messages, false);
|
||||
|
||||
let retry_response = self.make_non_streaming_request(clarify_request).await?;
|
||||
let retry_response = self.make_non_streaming_request(clarify_request).await?;
|
||||
|
||||
if let Some(choice) = retry_response.choices.first() {
|
||||
if let Some(content) = &choice.message.content {
|
||||
model_response = content.clone();
|
||||
}
|
||||
if let Some(choice) = retry_response.choices.first() {
|
||||
if let Some(content) = &choice.message.content {
|
||||
model_response = content.clone();
|
||||
}
|
||||
} else {
|
||||
model_response = hallucination_state.tokens.join("");
|
||||
}
|
||||
} else {
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?;
|
||||
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
|
||||
if let Some(choice) = choices.first() {
|
||||
if let Some(content) = choice.get("delta")
|
||||
.and_then(|d| d.get("content"))
|
||||
.and_then(|c| c.as_str()) {
|
||||
model_response.push_str(content);
|
||||
}
|
||||
model_response = hallucination_state.tokens.join("");
|
||||
}
|
||||
} else {
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
let chunk = chunk_result.map_err(FunctionCallingError::InvalidModelResponse)?;
|
||||
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
|
||||
if let Some(choice) = choices.first() {
|
||||
if let Some(content) = choice
|
||||
.get("delta")
|
||||
.and_then(|d| d.get("content"))
|
||||
.and_then(|c| c.as_str())
|
||||
{
|
||||
model_response.push_str(content);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1009,10 +1072,17 @@ impl ArchFunctionHandler {
|
|||
|
||||
let response_dict = self.parse_model_response(&model_response);
|
||||
|
||||
info!("[arch-fc]: raw model response: {}", response_dict.raw_response);
|
||||
info!(
|
||||
"[arch-fc]: raw model response: {}",
|
||||
response_dict.raw_response
|
||||
);
|
||||
|
||||
// General model response (no intent matched - should route to default target)
|
||||
let model_message = if response_dict.response.as_ref().map_or(false, |s| !s.is_empty()) {
|
||||
let model_message = if response_dict
|
||||
.response
|
||||
.as_ref()
|
||||
.is_some_and(|s| !s.is_empty())
|
||||
{
|
||||
// When arch-fc returns a "response" field, it means no intent was matched
|
||||
// Return empty content and empty tool_calls so prompt_gateway routes to default target
|
||||
ResponseMessage {
|
||||
|
|
@ -1053,8 +1123,11 @@ impl ArchFunctionHandler {
|
|||
let verification = self.verify_tool_calls(tools, &response_dict.tool_calls);
|
||||
|
||||
if verification.is_valid {
|
||||
info!("[Tool calls]: {:?}",
|
||||
response_dict.tool_calls.iter()
|
||||
info!(
|
||||
"[Tool calls]: {:?}",
|
||||
response_dict
|
||||
.tool_calls
|
||||
.iter()
|
||||
.map(|tc| &tc.function)
|
||||
.collect::<Vec<_>>()
|
||||
);
|
||||
|
|
@ -1092,8 +1165,11 @@ impl ArchFunctionHandler {
|
|||
}
|
||||
}
|
||||
} else {
|
||||
info!("[Tool calls]: {:?}",
|
||||
response_dict.tool_calls.iter()
|
||||
info!(
|
||||
"[Tool calls]: {:?}",
|
||||
response_dict
|
||||
.tool_calls
|
||||
.iter()
|
||||
.map(|tc| &tc.function)
|
||||
.collect::<Vec<_>>()
|
||||
);
|
||||
|
|
@ -1108,7 +1184,10 @@ impl ArchFunctionHandler {
|
|||
}
|
||||
}
|
||||
} else {
|
||||
error!("Invalid tool calls in response: {}", response_dict.error_message);
|
||||
error!(
|
||||
"Invalid tool calls in response: {}",
|
||||
response_dict.error_message
|
||||
);
|
||||
ResponseMessage {
|
||||
role: Role::Assistant,
|
||||
content: Some(String::new()),
|
||||
|
|
@ -1243,7 +1322,6 @@ pub async fn function_calling_chat_handler(
|
|||
req: Request<Incoming>,
|
||||
llm_provider_url: String,
|
||||
) -> std::result::Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
|
||||
use hermesllm::apis::openai::ChatCompletionsRequest;
|
||||
let whole_body = req.collect().await?.to_bytes();
|
||||
|
||||
|
|
@ -1255,10 +1333,13 @@ pub async fn function_calling_chat_handler(
|
|||
let mut response = Response::new(full(
|
||||
serde_json::json!({
|
||||
"error": format!("Invalid request body: {}", e)
|
||||
}).to_string()
|
||||
})
|
||||
.to_string(),
|
||||
));
|
||||
*response.status_mut() = StatusCode::BAD_REQUEST;
|
||||
response.headers_mut().insert("Content-Type", "application/json".parse().unwrap());
|
||||
response
|
||||
.headers_mut()
|
||||
.insert("Content-Type", "application/json".parse().unwrap());
|
||||
return Ok(response);
|
||||
}
|
||||
};
|
||||
|
|
@ -1271,24 +1352,31 @@ pub async fn function_calling_chat_handler(
|
|||
// Parse as ChatCompletionsRequest
|
||||
let chat_request: ChatCompletionsRequest = match serde_json::from_value(body_json) {
|
||||
Ok(req) => {
|
||||
info!("[request body]: {}", serde_json::to_string(&req).unwrap_or_default());
|
||||
info!(
|
||||
"[request body]: {}",
|
||||
serde_json::to_string(&req).unwrap_or_default()
|
||||
);
|
||||
req
|
||||
},
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to parse request body: {}", e);
|
||||
let mut response = Response::new(full(
|
||||
serde_json::json!({
|
||||
"error": format!("Invalid request body: {}", e)
|
||||
}).to_string()
|
||||
})
|
||||
.to_string(),
|
||||
));
|
||||
*response.status_mut() = StatusCode::BAD_REQUEST;
|
||||
response.headers_mut().insert("Content-Type", "application/json".parse().unwrap());
|
||||
response
|
||||
.headers_mut()
|
||||
.insert("Content-Type", "application/json".parse().unwrap());
|
||||
return Ok(response);
|
||||
}
|
||||
};
|
||||
|
||||
// Determine which handler to use based on metadata
|
||||
let use_agent_orchestrator = chat_request.metadata
|
||||
let use_agent_orchestrator = chat_request
|
||||
.metadata
|
||||
.as_ref()
|
||||
.and_then(|m| m.get("use_agent_orchestrator"))
|
||||
.and_then(|v| v.as_bool())
|
||||
|
|
@ -1309,7 +1397,10 @@ pub async fn function_calling_chat_handler(
|
|||
ARCH_FUNCTION_MODEL_NAME.to_string(),
|
||||
llm_provider_url.clone(),
|
||||
);
|
||||
handler.function_handler.function_calling_chat(chat_request).await
|
||||
handler
|
||||
.function_handler
|
||||
.function_calling_chat(chat_request)
|
||||
.await
|
||||
} else {
|
||||
let handler = ArchFunctionHandler::new(
|
||||
ARCH_FUNCTION_MODEL_NAME.to_string(),
|
||||
|
|
@ -1328,7 +1419,9 @@ pub async fn function_calling_chat_handler(
|
|||
|
||||
let mut response = Response::new(full(response_json));
|
||||
*response.status_mut() = StatusCode::OK;
|
||||
response.headers_mut().insert("Content-Type", "application/json".parse().unwrap());
|
||||
response
|
||||
.headers_mut()
|
||||
.insert("Content-Type", "application/json".parse().unwrap());
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
|
@ -1341,13 +1434,14 @@ pub async fn function_calling_chat_handler(
|
|||
|
||||
let mut response = Response::new(full(error_response.to_string()));
|
||||
*response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
response.headers_mut().insert("Content-Type", "application/json".parse().unwrap());
|
||||
response
|
||||
.headers_mut()
|
||||
.insert("Content-Type", "application/json".parse().unwrap());
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// ============================================================================
|
||||
// TESTS
|
||||
// ============================================================================
|
||||
|
|
@ -1370,10 +1464,13 @@ mod tests {
|
|||
assert!(config.task_prompt.contains("</tools>\\n\\n"));
|
||||
|
||||
// Format prompt should contain literal escaped newlines and proper JSON examples
|
||||
assert!(config.format_prompt.contains("\\n\\nBased on your analysis"));
|
||||
assert!(config.format_prompt.contains(r#"{\"response\": \"Your response text here\"}"#));
|
||||
assert!(config
|
||||
.format_prompt
|
||||
.contains("\\n\\nBased on your analysis"));
|
||||
assert!(config
|
||||
.format_prompt
|
||||
.contains(r#"{\"response\": \"Your response text here\"}"#));
|
||||
assert!(config.format_prompt.contains(r#"{\"tool_calls\": [{"#));
|
||||
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -1384,7 +1481,11 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_fix_json_string_valid() {
|
||||
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string());
|
||||
let handler = ArchFunctionHandler::new(
|
||||
"test-model".to_string(),
|
||||
ArchFunctionConfig::default(),
|
||||
"http://localhost:8000".to_string(),
|
||||
);
|
||||
let json_str = r#"{"name": "test", "value": 123}"#;
|
||||
let result = handler.fix_json_string(json_str);
|
||||
assert!(result.is_ok());
|
||||
|
|
@ -1392,7 +1493,11 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_fix_json_string_missing_bracket() {
|
||||
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string());
|
||||
let handler = ArchFunctionHandler::new(
|
||||
"test-model".to_string(),
|
||||
ArchFunctionConfig::default(),
|
||||
"http://localhost:8000".to_string(),
|
||||
);
|
||||
let json_str = r#"{"name": "test", "value": 123"#;
|
||||
let result = handler.fix_json_string(json_str);
|
||||
assert!(result.is_ok());
|
||||
|
|
@ -1402,8 +1507,13 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_parse_model_response_with_tool_calls() {
|
||||
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string());
|
||||
let content = r#"{"tool_calls": [{"name": "get_weather", "arguments": {"location": "NYC"}}]}"#;
|
||||
let handler = ArchFunctionHandler::new(
|
||||
"test-model".to_string(),
|
||||
ArchFunctionConfig::default(),
|
||||
"http://localhost:8000".to_string(),
|
||||
);
|
||||
let content =
|
||||
r#"{"tool_calls": [{"name": "get_weather", "arguments": {"location": "NYC"}}]}"#;
|
||||
let result = handler.parse_model_response(content);
|
||||
|
||||
assert!(result.is_valid);
|
||||
|
|
@ -1413,8 +1523,13 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_parse_model_response_with_clarification() {
|
||||
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string());
|
||||
let content = r#"{"required_functions": ["get_weather"], "clarification": "What location?"}"#;
|
||||
let handler = ArchFunctionHandler::new(
|
||||
"test-model".to_string(),
|
||||
ArchFunctionConfig::default(),
|
||||
"http://localhost:8000".to_string(),
|
||||
);
|
||||
let content =
|
||||
r#"{"required_functions": ["get_weather"], "clarification": "What location?"}"#;
|
||||
let result = handler.parse_model_response(content);
|
||||
|
||||
assert!(result.is_valid);
|
||||
|
|
@ -1424,7 +1539,11 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_convert_data_type_int_to_float() {
|
||||
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string());
|
||||
let handler = ArchFunctionHandler::new(
|
||||
"test-model".to_string(),
|
||||
ArchFunctionConfig::default(),
|
||||
"http://localhost:8000".to_string(),
|
||||
);
|
||||
let value = json!(42);
|
||||
let result = handler.convert_data_type(&value, "float");
|
||||
assert!(result.is_ok());
|
||||
|
|
@ -1504,13 +1623,12 @@ pub fn check_threshold(
|
|||
}
|
||||
|
||||
/// Checks if a parameter is required in the function description
|
||||
pub fn is_parameter_required(
|
||||
function_description: &Value,
|
||||
parameter_name: &str,
|
||||
) -> bool {
|
||||
pub fn is_parameter_required(function_description: &Value, parameter_name: &str) -> bool {
|
||||
if let Some(required) = function_description.get("required") {
|
||||
if let Some(required_arr) = required.as_array() {
|
||||
return required_arr.iter().any(|v| v.as_str() == Some(parameter_name));
|
||||
return required_arr
|
||||
.iter()
|
||||
.any(|v| v.as_str() == Some(parameter_name));
|
||||
}
|
||||
}
|
||||
false
|
||||
|
|
@ -1559,12 +1677,7 @@ impl HallucinationState {
|
|||
pub fn new(functions: &[Tool]) -> Self {
|
||||
let function_properties: HashMap<String, Value> = functions
|
||||
.iter()
|
||||
.map(|tool| {
|
||||
(
|
||||
tool.function.name.clone(),
|
||||
tool.function.parameters.clone(),
|
||||
)
|
||||
})
|
||||
.map(|tool| (tool.function.name.clone(), tool.function.parameters.clone()))
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
|
|
@ -1620,7 +1733,10 @@ impl HallucinationState {
|
|||
|
||||
// Function name extraction logic
|
||||
if self.state.as_deref() == Some("function_name") {
|
||||
if !FUNC_NAME_END_TOKEN.iter().any(|&t| self.tokens.last().map_or(false, |tok| tok == t)) {
|
||||
if !FUNC_NAME_END_TOKEN
|
||||
.iter()
|
||||
.any(|&t| self.tokens.last().is_some_and(|tok| tok == t))
|
||||
{
|
||||
self.mask.push(MaskToken::FunctionName);
|
||||
} else {
|
||||
self.state = None;
|
||||
|
|
@ -1629,34 +1745,51 @@ impl HallucinationState {
|
|||
}
|
||||
|
||||
// Check for function name start
|
||||
if FUNC_NAME_START_PATTERN.iter().any(|&p| content.ends_with(p)) {
|
||||
if FUNC_NAME_START_PATTERN
|
||||
.iter()
|
||||
.any(|&p| content.ends_with(p))
|
||||
{
|
||||
self.state = Some("function_name".to_string());
|
||||
}
|
||||
|
||||
// Parameter name extraction logic
|
||||
if self.state.as_deref() == Some("parameter_name")
|
||||
&& !PARAMETER_NAME_END_TOKENS.iter().any(|&t| content.ends_with(t)) {
|
||||
&& !PARAMETER_NAME_END_TOKENS
|
||||
.iter()
|
||||
.any(|&t| content.ends_with(t))
|
||||
{
|
||||
self.mask.push(MaskToken::ParameterName);
|
||||
} else if self.state.as_deref() == Some("parameter_name")
|
||||
&& PARAMETER_NAME_END_TOKENS.iter().any(|&t| content.ends_with(t)) {
|
||||
&& PARAMETER_NAME_END_TOKENS
|
||||
.iter()
|
||||
.any(|&t| content.ends_with(t))
|
||||
{
|
||||
self.state = None;
|
||||
self.parameter_name_done = true;
|
||||
self.get_parameter_name();
|
||||
} else if self.parameter_name_done
|
||||
&& !self.open_bracket
|
||||
&& PARAMETER_NAME_START_PATTERN.iter().any(|&p| content.ends_with(p)) {
|
||||
&& PARAMETER_NAME_START_PATTERN
|
||||
.iter()
|
||||
.any(|&p| content.ends_with(p))
|
||||
{
|
||||
self.state = Some("parameter_name".to_string());
|
||||
}
|
||||
|
||||
// First parameter value start
|
||||
if FIRST_PARAM_NAME_START_PATTERN.iter().any(|&p| content.ends_with(p)) {
|
||||
if FIRST_PARAM_NAME_START_PATTERN
|
||||
.iter()
|
||||
.any(|&p| content.ends_with(p))
|
||||
{
|
||||
self.state = Some("parameter_name".to_string());
|
||||
}
|
||||
|
||||
// Parameter value extraction logic
|
||||
if self.state.as_deref() == Some("parameter_value")
|
||||
&& !PARAMETER_VALUE_END_TOKEN.iter().any(|&t| content.ends_with(t)) {
|
||||
|
||||
&& !PARAMETER_VALUE_END_TOKEN
|
||||
.iter()
|
||||
.any(|&t| content.ends_with(t))
|
||||
{
|
||||
// Check for brackets
|
||||
if let Some(last_token) = self.tokens.last() {
|
||||
let open_brackets: Vec<char> = last_token
|
||||
|
|
@ -1694,8 +1827,11 @@ impl HallucinationState {
|
|||
&& self.mask[self.mask.len() - 2] != MaskToken::ParameterValue
|
||||
&& !self.parameter_name.is_empty()
|
||||
{
|
||||
let last_param = self.parameter_name[self.parameter_name.len() - 1].clone();
|
||||
if let Some(func_props) = self.function_properties.get(&self.function_name) {
|
||||
let last_param =
|
||||
self.parameter_name[self.parameter_name.len() - 1].clone();
|
||||
if let Some(func_props) =
|
||||
self.function_properties.get(&self.function_name)
|
||||
{
|
||||
if is_parameter_required(func_props, &last_param)
|
||||
&& !is_parameter_property(func_props, &last_param, "enum")
|
||||
&& !self.check_parameter_name.contains_key(&last_param)
|
||||
|
|
@ -1718,10 +1854,16 @@ impl HallucinationState {
|
|||
}
|
||||
} else if self.state.as_deref() == Some("parameter_value")
|
||||
&& !self.open_bracket
|
||||
&& PARAMETER_VALUE_END_TOKEN.iter().any(|&t| content.ends_with(t)) {
|
||||
&& PARAMETER_VALUE_END_TOKEN
|
||||
.iter()
|
||||
.any(|&t| content.ends_with(t))
|
||||
{
|
||||
self.state = None;
|
||||
} else if self.parameter_name_done
|
||||
&& PARAMETER_VALUE_START_PATTERN.iter().any(|&p| content.ends_with(p)) {
|
||||
&& PARAMETER_VALUE_START_PATTERN
|
||||
.iter()
|
||||
.any(|&p| content.ends_with(p))
|
||||
{
|
||||
self.state = Some("parameter_value".to_string());
|
||||
}
|
||||
|
||||
|
|
@ -1848,18 +1990,18 @@ mod hallucination_tests {
|
|||
let handler = ArchFunctionHandler::new(
|
||||
"test-model".to_string(),
|
||||
ArchFunctionConfig::default(),
|
||||
"http://localhost:8000".to_string()
|
||||
"http://localhost:8000".to_string(),
|
||||
);
|
||||
|
||||
// Test integer types
|
||||
assert!(handler.check_value_type(&json!(42), "integer"));
|
||||
assert!(handler.check_value_type(&json!(42), "int"));
|
||||
assert!(!handler.check_value_type(&json!(3.14), "integer"));
|
||||
assert!(!handler.check_value_type(&json!(3.15), "integer"));
|
||||
|
||||
// Test number types (accepts both int and float)
|
||||
assert!(handler.check_value_type(&json!(3.14), "number"));
|
||||
assert!(handler.check_value_type(&json!(3.15), "number"));
|
||||
assert!(handler.check_value_type(&json!(42), "number"));
|
||||
assert!(handler.check_value_type(&json!(3.14), "float"));
|
||||
assert!(handler.check_value_type(&json!(3.15), "float"));
|
||||
|
||||
// Test boolean
|
||||
assert!(handler.check_value_type(&json!(true), "boolean"));
|
||||
|
|
@ -1890,12 +2032,16 @@ mod hallucination_tests {
|
|||
let handler = ArchFunctionHandler::new(
|
||||
"test-model".to_string(),
|
||||
ArchFunctionConfig::default(),
|
||||
"http://localhost:8000".to_string()
|
||||
"http://localhost:8000".to_string(),
|
||||
);
|
||||
|
||||
// Test valid type - no conversion needed
|
||||
assert!(handler.validate_or_convert_parameter(&json!(42), "integer").unwrap());
|
||||
assert!(handler.validate_or_convert_parameter(&json!("hello"), "string").unwrap());
|
||||
assert!(handler
|
||||
.validate_or_convert_parameter(&json!(42), "integer")
|
||||
.unwrap());
|
||||
assert!(handler
|
||||
.validate_or_convert_parameter(&json!("hello"), "string")
|
||||
.unwrap());
|
||||
|
||||
// Test integer to float conversion (convert_data_type supports this)
|
||||
let result = handler.validate_or_convert_parameter(&json!(42), "float");
|
||||
|
|
@ -1910,8 +2056,12 @@ mod hallucination_tests {
|
|||
assert!(!result.unwrap());
|
||||
|
||||
// Test number accepting both int and float
|
||||
assert!(handler.validate_or_convert_parameter(&json!(42), "number").unwrap());
|
||||
assert!(handler.validate_or_convert_parameter(&json!(3.14), "number").unwrap());
|
||||
assert!(handler
|
||||
.validate_or_convert_parameter(&json!(42), "number")
|
||||
.unwrap());
|
||||
assert!(handler
|
||||
.validate_or_convert_parameter(&json!(3.15), "number")
|
||||
.unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue