cargo clippy (#660)

This commit is contained in:
Adil Hafeez 2025-12-25 21:08:37 -08:00 committed by GitHub
parent c75e7606f9
commit ca95ffb63d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
62 changed files with 1864 additions and 1187 deletions

View file

@ -3,7 +3,7 @@ use std::time::{Instant, SystemTime};
use bytes::Bytes;
use common::consts::TRACE_PARENT_HEADER;
use common::traces::{SpanBuilder, SpanKind, parse_traceparent, generate_random_span_id};
use common::traces::{generate_random_span_id, parse_traceparent, SpanBuilder, SpanKind};
use hermesllm::apis::OpenAIMessage;
use hermesllm::clients::SupportedAPIsFromClient;
use hermesllm::providers::request::ProviderRequest;
@ -18,7 +18,7 @@ use super::agent_selector::{AgentSelectionError, AgentSelector};
use super::pipeline_processor::{PipelineError, PipelineProcessor};
use super::response_handler::ResponseHandler;
use crate::router::plano_orchestrator::OrchestratorService;
use crate::tracing::{OperationNameBuilder, operation_component, http};
use crate::tracing::{http, operation_component, OperationNameBuilder};
/// Main errors for agent chat completions
#[derive(Debug, thiserror::Error)]
@ -61,7 +61,6 @@ pub async fn agent_chat(
body,
}) = &err
{
warn!(
"Client error from agent '{}' (HTTP {}): {}",
agent, status, body
@ -77,8 +76,8 @@ pub async fn agent_chat(
let json_string = error_json.to_string();
let mut response = Response::new(ResponseHandler::create_full_body(json_string));
*response.status_mut() = hyper::StatusCode::from_u16(*status)
.unwrap_or(hyper::StatusCode::BAD_REQUEST);
*response.status_mut() =
hyper::StatusCode::from_u16(*status).unwrap_or(hyper::StatusCode::BAD_REQUEST);
response.headers_mut().insert(
hyper::header::CONTENT_TYPE,
"application/json".parse().unwrap(),
@ -234,8 +233,18 @@ async fn handle_agent_chat(
.with_attribute(http::TARGET, "/agents/select")
.with_attribute("selection.listener", listener.name.clone())
.with_attribute("selection.agent_count", selected_agents.len().to_string())
.with_attribute("selection.agents", selected_agents.iter().map(|a| a.id.as_str()).collect::<Vec<_>>().join(","))
.with_attribute("duration_ms", format!("{:.2}", selection_elapsed.as_secs_f64() * 1000.0));
.with_attribute(
"selection.agents",
selected_agents
.iter()
.map(|a| a.id.as_str())
.collect::<Vec<_>>()
.join(","),
)
.with_attribute(
"duration_ms",
format!("{:.2}", selection_elapsed.as_secs_f64() * 1000.0),
);
if !trace_id.is_empty() {
selection_span_builder = selection_span_builder.with_trace_id(trace_id.clone());
@ -318,8 +327,14 @@ async fn handle_agent_chat(
.with_attribute(http::METHOD, "POST")
.with_attribute(http::TARGET, full_path)
.with_attribute("agent.name", agent_name.clone())
.with_attribute("agent.sequence", format!("{}/{}", agent_index + 1, agent_count))
.with_attribute("duration_ms", format!("{:.2}", agent_elapsed.as_secs_f64() * 1000.0));
.with_attribute(
"agent.sequence",
format!("{}/{}", agent_index + 1, agent_count),
)
.with_attribute(
"duration_ms",
format!("{:.2}", agent_elapsed.as_secs_f64() * 1000.0),
);
if !trace_id.is_empty() {
span_builder = span_builder.with_trace_id(trace_id.clone());
@ -333,7 +348,10 @@ async fn handle_agent_chat(
// If this is the last agent, return the streaming response
if is_last_agent {
info!("Completed agent chain, returning response from last agent: {}", agent_name);
info!(
"Completed agent chain, returning response from last agent: {}",
agent_name
);
return response_handler
.create_streaming_response(llm_response)
.await
@ -341,7 +359,10 @@ async fn handle_agent_chat(
}
// For intermediate agents, collect the full response and pass to next agent
debug!("Collecting response from intermediate agent: {}", agent_name);
debug!(
"Collecting response from intermediate agent: {}",
agent_name
);
let response_text = response_handler.collect_full_response(llm_response).await?;
info!(
@ -364,7 +385,6 @@ async fn handle_agent_chat(
});
current_messages.push(last_message);
}
// This should never be reached since we return in the last agent iteration

View file

@ -2,7 +2,7 @@ use std::collections::HashMap;
use std::sync::Arc;
use common::configuration::{
Agent, AgentFilterChain, Listener, AgentUsagePreference, OrchestrationPreference,
Agent, AgentFilterChain, AgentUsagePreference, Listener, OrchestrationPreference,
};
use hermesllm::apis::openai::Message;
use tracing::{debug, warn};

View file

@ -1,20 +1,18 @@
use bytes::Bytes;
use eventsource_stream::Eventsource;
use futures::StreamExt;
use hermesllm::apis::openai::{
ChatCompletionsRequest, ChatCompletionsResponse, Choice, FinishReason, FunctionCall, Message,
MessageContent, ResponseMessage, Role, Tool, ToolCall, Usage,
};
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use hyper::body::Incoming;
use hyper::{Request, Response, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
use thiserror::Error;
use tracing::{info, error};
use futures::StreamExt;
use bytes::Bytes;
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use hyper::body::Incoming;
use hyper::{Request, Response, StatusCode};
use eventsource_stream::Eventsource;
use tracing::{error, info};
// ============================================================================
// CONSTANTS FOR HALLUCINATION DETECTION
@ -273,17 +271,14 @@ impl ArchFunctionHandler {
let mut stack: Vec<char> = Vec::new();
let mut fixed_str = String::new();
let matching_bracket: HashMap<char, char> =
[(')', '('), ('}', '{'), (']', '[')]
.iter()
.cloned()
.collect();
let opening_bracket: HashMap<char, char> = matching_bracket
let matching_bracket: HashMap<char, char> = [(')', '('), ('}', '{'), (']', '[')]
.iter()
.map(|(k, v)| (*v, *k))
.cloned()
.collect();
let opening_bracket: HashMap<char, char> =
matching_bracket.iter().map(|(k, v)| (*v, *k)).collect();
for ch in json_str.chars() {
if ch == '{' || ch == '[' || ch == '(' {
stack.push(ch);
@ -332,12 +327,18 @@ impl ArchFunctionHandler {
// Remove markdown code blocks
let mut content = content.trim().to_string();
if content.starts_with("```") && content.ends_with("```") {
content = content.trim_start_matches("```").trim_end_matches("```").to_string();
content = content
.trim_start_matches("```")
.trim_end_matches("```")
.to_string();
if content.starts_with("json") {
content = content.trim_start_matches("json").to_string();
}
// Trim again after removing code blocks to eliminate internal whitespace
content = content.trim_start_matches(r"\n").trim_end_matches(r"\n").to_string();
content = content
.trim_start_matches(r"\n")
.trim_end_matches(r"\n")
.to_string();
content = content.trim().to_string();
// Unescape the quotes: \" -> "
// The model sometimes returns escaped JSON inside markdown blocks
@ -453,12 +454,12 @@ impl ArchFunctionHandler {
/// Helper method to check if a value matches the expected type
fn check_value_type(&self, value: &Value, target_type: &str) -> bool {
match target_type {
"int" | "integer" => value.is_i64() || value.is_u64(),
"int" | "integer" => value.is_i64() || value.is_u64(),
"float" | "number" => value.is_f64() || value.is_i64() || value.is_u64(),
"bool" | "boolean" => value.is_boolean(),
"str" | "string" => value.is_string(),
"list" | "array" => value.is_array(),
"dict" | "object" => value.is_object(),
"bool" | "boolean" => value.is_boolean(),
"str" | "string" => value.is_string(),
"list" | "array" => value.is_array(),
"dict" | "object" => value.is_object(),
_ => true,
}
}
@ -505,15 +506,19 @@ impl ArchFunctionHandler {
let func_name = &tool_call.function.name;
// Parse arguments as JSON
let func_args: HashMap<String, Value> = match serde_json::from_str(&tool_call.function.arguments) {
Ok(args) => args,
Err(e) => {
verification.is_valid = false;
verification.invalid_tool_call = Some(tool_call.clone());
verification.error_message = format!("Failed to parse arguments for function '{}': {}", func_name, e);
break;
}
};
let func_args: HashMap<String, Value> =
match serde_json::from_str(&tool_call.function.arguments) {
Ok(args) => args,
Err(e) => {
verification.is_valid = false;
verification.invalid_tool_call = Some(tool_call.clone());
verification.error_message = format!(
"Failed to parse arguments for function '{}': {}",
func_name, e
);
break;
}
};
// Check if function is available
if let Some(function_params) = functions.get(func_name) {
@ -541,14 +546,23 @@ impl ArchFunctionHandler {
if let Some(properties_obj) = properties.as_object() {
for (param_name, param_value) in &func_args {
if let Some(param_schema) = properties_obj.get(param_name) {
if let Some(target_type) = param_schema.get("type").and_then(|v| v.as_str()) {
if self.config.support_data_types.contains(&target_type.to_string()) {
if let Some(target_type) =
param_schema.get("type").and_then(|v| v.as_str())
{
if self
.config
.support_data_types
.contains(&target_type.to_string())
{
// Validate data type using helper method
match self.validate_or_convert_parameter(param_value, target_type) {
match self
.validate_or_convert_parameter(param_value, target_type)
{
Ok(is_valid) => {
if !is_valid {
verification.is_valid = false;
verification.invalid_tool_call = Some(tool_call.clone());
verification.invalid_tool_call =
Some(tool_call.clone());
verification.error_message = format!(
"Parameter `{}` is expected to have the data type `{}`, got incompatible type.",
param_name, target_type
@ -558,7 +572,8 @@ impl ArchFunctionHandler {
}
Err(_) => {
verification.is_valid = false;
verification.invalid_tool_call = Some(tool_call.clone());
verification.invalid_tool_call =
Some(tool_call.clone());
verification.error_message = format!(
"Parameter `{}` is expected to have the data type `{}`, got incompatible type.",
param_name, target_type
@ -569,7 +584,10 @@ impl ArchFunctionHandler {
} else {
verification.is_valid = false;
verification.invalid_tool_call = Some(tool_call.clone());
verification.error_message = format!("Data type `{}` is not supported.", target_type);
verification.error_message = format!(
"Data type `{}` is not supported.",
target_type
);
break;
}
}
@ -598,11 +616,8 @@ impl ArchFunctionHandler {
/// Formats the system prompt with tools
pub fn format_system_prompt(&self, tools: &[Tool]) -> Result<String> {
let tools_str = self.convert_tools(tools)?;
let system_prompt = self
.config
.task_prompt
.replace("{tools}", &tools_str)
+ &self.config.format_prompt;
let system_prompt =
self.config.task_prompt.replace("{tools}", &tools_str) + &self.config.format_prompt;
Ok(system_prompt)
}
@ -665,15 +680,22 @@ impl ArchFunctionHandler {
// Strip markdown code blocks
if tool_call_msg.starts_with("```") && tool_call_msg.ends_with("```") {
tool_call_msg = tool_call_msg.trim_start_matches("```").trim_end_matches("```").trim().to_string();
tool_call_msg = tool_call_msg
.trim_start_matches("```")
.trim_end_matches("```")
.trim()
.to_string();
if tool_call_msg.starts_with("json") {
tool_call_msg = tool_call_msg.trim_start_matches("json").trim().to_string();
tool_call_msg =
tool_call_msg.trim_start_matches("json").trim().to_string();
}
}
// Extract function name
if let Ok(parsed) = serde_json::from_str::<Value>(&tool_call_msg) {
if let Some(tool_calls_arr) = parsed.get("tool_calls").and_then(|v| v.as_array()) {
if let Some(tool_calls_arr) =
parsed.get("tool_calls").and_then(|v| v.as_array())
{
if let Some(first_tool_call) = tool_calls_arr.first() {
let func_name = first_tool_call
.get("name")
@ -685,8 +707,10 @@ impl ArchFunctionHandler {
"result": content,
});
content = format!("<tool_response>\n{}\n</tool_response>",
serde_json::to_string(&tool_response)?);
content = format!(
"<tool_response>\n{}\n</tool_response>",
serde_json::to_string(&tool_response)?
);
}
}
}
@ -717,7 +741,7 @@ impl ArchFunctionHandler {
if let Some(instruction) = extra_instruction {
if let Some(last) = processed_messages.last_mut() {
if let MessageContent::Text(content) = &mut last.content {
content.push_str("\n");
content.push('\n');
content.push_str(instruction);
}
}
@ -750,13 +774,11 @@ impl ArchFunctionHandler {
for i in (conversation_idx..messages.len()).rev() {
if let MessageContent::Text(content) = &messages[i].content {
num_tokens += content.len() / 4;
if num_tokens >= max_tokens {
if messages[i].role == Role::User {
// Set message_idx to current position and break
// This matches Python's behavior where message_idx is set before break
message_idx = i;
break;
}
if num_tokens >= max_tokens && messages[i].role == Role::User {
// Set message_idx to current position and break
// This matches Python's behavior where message_idx is set before break
message_idx = i;
break;
}
}
// Only update message_idx if we haven't hit the token limit yet
@ -789,7 +811,11 @@ impl ArchFunctionHandler {
}
/// Helper to create a request with VLLM-specific parameters
fn create_request_with_extra_body(&self, messages: Vec<Message>, stream: bool) -> ChatCompletionsRequest {
fn create_request_with_extra_body(
&self,
messages: Vec<Message>,
stream: bool,
) -> ChatCompletionsRequest {
ChatCompletionsRequest {
model: self.model_name.clone(),
messages,
@ -813,24 +839,38 @@ impl ArchFunctionHandler {
}
/// Makes a streaming request and returns the SSE event stream
async fn make_streaming_request(&self, request: ChatCompletionsRequest) -> Result<std::pin::Pin<Box<dyn futures::Stream<Item = std::result::Result<Value, String>> + Send>>> {
let request_body = serde_json::to_string(&request)
.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Failed to serialize request: {}", e)))?;
async fn make_streaming_request(
&self,
request: ChatCompletionsRequest,
) -> Result<
std::pin::Pin<Box<dyn futures::Stream<Item = std::result::Result<Value, String>> + Send>>,
> {
let request_body = serde_json::to_string(&request).map_err(|e| {
FunctionCallingError::InvalidModelResponse(format!(
"Failed to serialize request: {}",
e
))
})?;
let response = self.http_client
let response = self
.http_client
.post(&self.endpoint_url)
.header("Content-Type", "application/json")
.body(request_body)
.send()
.await
.map_err(|e| FunctionCallingError::HttpError(e))?;
.map_err(FunctionCallingError::HttpError)?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
return Err(FunctionCallingError::InvalidModelResponse(
format!("HTTP error {}: {}", status, error_text)
));
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(FunctionCallingError::InvalidModelResponse(format!(
"HTTP error {}: {}",
status, error_text
)));
}
// Parse SSE stream
@ -856,38 +896,51 @@ impl ArchFunctionHandler {
}
/// Makes a non-streaming request and returns the response
async fn make_non_streaming_request(&self, request: ChatCompletionsRequest) -> Result<ChatCompletionsResponse> {
let request_body = serde_json::to_string(&request)
.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Failed to serialize request: {}", e)))?;
async fn make_non_streaming_request(
&self,
request: ChatCompletionsRequest,
) -> Result<ChatCompletionsResponse> {
let request_body = serde_json::to_string(&request).map_err(|e| {
FunctionCallingError::InvalidModelResponse(format!(
"Failed to serialize request: {}",
e
))
})?;
let response = self.http_client
let response = self
.http_client
.post(&self.endpoint_url)
.header("Content-Type", "application/json")
.body(request_body)
.send()
.await
.map_err(|e| FunctionCallingError::HttpError(e))?;
.map_err(FunctionCallingError::HttpError)?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
return Err(FunctionCallingError::InvalidModelResponse(
format!("HTTP error {}: {}", status, error_text)
));
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(FunctionCallingError::InvalidModelResponse(format!(
"HTTP error {}: {}",
status, error_text
)));
}
let response_text = response.text().await
.map_err(|e| FunctionCallingError::HttpError(e))?;
let response_text = response
.text()
.await
.map_err(FunctionCallingError::HttpError)?;
serde_json::from_str(&response_text)
.map_err(|e| FunctionCallingError::JsonParseError(e))
serde_json::from_str(&response_text).map_err(FunctionCallingError::JsonParseError)
}
pub async fn function_calling_chat(
&self,
request: ChatCompletionsRequest,
) -> Result<ChatCompletionsResponse> {
use tracing::{info, error};
use tracing::{error, info};
info!("[Arch-Function] - ChatCompletion");
@ -899,10 +952,14 @@ impl ArchFunctionHandler {
request.metadata.as_ref(),
)?;
info!("[request to arch-fc]: model: {}, messages count: {}",
self.model_name, messages.len());
info!(
"[request to arch-fc]: model: {}, messages count: {}",
self.model_name,
messages.len()
);
let use_agent_orchestrator = request.metadata
let use_agent_orchestrator = request
.metadata
.as_ref()
.and_then(|m| m.get("use_agent_orchestrator"))
.and_then(|v| v.as_bool())
@ -918,89 +975,95 @@ impl ArchFunctionHandler {
if use_agent_orchestrator {
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?;
let chunk = chunk_result.map_err(FunctionCallingError::InvalidModelResponse)?;
// Extract content from JSON response
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
if let Some(choice) = choices.first() {
if let Some(content) = choice.get("delta")
if let Some(content) = choice
.get("delta")
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str()) {
.and_then(|c| c.as_str())
{
model_response.push_str(content);
}
}
}
}
info!("[Agent Orchestrator]: response received");
} else {
if let Some(tools) = request.tools.as_ref() {
let mut hallucination_state = HallucinationState::new(tools);
let mut has_tool_calls = None;
let mut has_hallucination = false;
} else if let Some(tools) = request.tools.as_ref() {
let mut hallucination_state = HallucinationState::new(tools);
let mut has_tool_calls = None;
let mut has_hallucination = false;
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?;
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result.map_err(FunctionCallingError::InvalidModelResponse)?;
// Extract content and logprobs from JSON response
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
if let Some(choice) = choices.first() {
if let Some(content) = choice.get("delta")
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str()) {
// Extract content and logprobs from JSON response
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
if let Some(choice) = choices.first() {
if let Some(content) = choice
.get("delta")
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str())
{
// Extract logprobs
let logprobs: Vec<f64> = choice
.get("logprobs")
.and_then(|lp| lp.get("content"))
.and_then(|c| c.as_array())
.and_then(|arr| arr.first())
.and_then(|token| token.get("top_logprobs"))
.and_then(|tlp| tlp.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.get("logprob").and_then(|lp| lp.as_f64()))
.collect()
})
.unwrap_or_default();
// Extract logprobs
let logprobs: Vec<f64> = choice.get("logprobs")
.and_then(|lp| lp.get("content"))
.and_then(|c| c.as_array())
.and_then(|arr| arr.first())
.and_then(|token| token.get("top_logprobs"))
.and_then(|tlp| tlp.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.get("logprob").and_then(|lp| lp.as_f64()))
.collect()
})
.unwrap_or_default();
if hallucination_state
.append_and_check_token_hallucination(content.to_string(), logprobs)
{
has_hallucination = true;
break;
}
if hallucination_state.append_and_check_token_hallucination(content.to_string(), logprobs) {
has_hallucination = true;
break;
}
if hallucination_state.tokens.len() > 5 && has_tool_calls.is_none() {
let collected_content = hallucination_state.tokens.join("");
has_tool_calls = Some(collected_content.contains("tool_calls"));
}
if hallucination_state.tokens.len() > 5 && has_tool_calls.is_none() {
let collected_content = hallucination_state.tokens.join("");
has_tool_calls = Some(collected_content.contains("tool_calls"));
}
}
}
}
}
if has_tool_calls == Some(true) && has_hallucination {
info!("[Hallucination]: {}", hallucination_state.error_message);
if has_tool_calls == Some(true) && has_hallucination {
info!("[Hallucination]: {}", hallucination_state.error_message);
let clarify_messages = self.prefill_message(messages.clone(), &self.clarify_prefix);
let clarify_request = self.create_request_with_extra_body(clarify_messages, false);
let clarify_messages = self.prefill_message(messages.clone(), &self.clarify_prefix);
let clarify_request = self.create_request_with_extra_body(clarify_messages, false);
let retry_response = self.make_non_streaming_request(clarify_request).await?;
let retry_response = self.make_non_streaming_request(clarify_request).await?;
if let Some(choice) = retry_response.choices.first() {
if let Some(content) = &choice.message.content {
model_response = content.clone();
}
if let Some(choice) = retry_response.choices.first() {
if let Some(content) = &choice.message.content {
model_response = content.clone();
}
} else {
model_response = hallucination_state.tokens.join("");
}
} else {
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?;
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
if let Some(choice) = choices.first() {
if let Some(content) = choice.get("delta")
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str()) {
model_response.push_str(content);
}
model_response = hallucination_state.tokens.join("");
}
} else {
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result.map_err(FunctionCallingError::InvalidModelResponse)?;
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
if let Some(choice) = choices.first() {
if let Some(content) = choice
.get("delta")
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str())
{
model_response.push_str(content);
}
}
}
@ -1009,10 +1072,17 @@ impl ArchFunctionHandler {
let response_dict = self.parse_model_response(&model_response);
info!("[arch-fc]: raw model response: {}", response_dict.raw_response);
info!(
"[arch-fc]: raw model response: {}",
response_dict.raw_response
);
// General model response (no intent matched - should route to default target)
let model_message = if response_dict.response.as_ref().map_or(false, |s| !s.is_empty()) {
let model_message = if response_dict
.response
.as_ref()
.is_some_and(|s| !s.is_empty())
{
// When arch-fc returns a "response" field, it means no intent was matched
// Return empty content and empty tool_calls so prompt_gateway routes to default target
ResponseMessage {
@ -1053,8 +1123,11 @@ impl ArchFunctionHandler {
let verification = self.verify_tool_calls(tools, &response_dict.tool_calls);
if verification.is_valid {
info!("[Tool calls]: {:?}",
response_dict.tool_calls.iter()
info!(
"[Tool calls]: {:?}",
response_dict
.tool_calls
.iter()
.map(|tc| &tc.function)
.collect::<Vec<_>>()
);
@ -1092,8 +1165,11 @@ impl ArchFunctionHandler {
}
}
} else {
info!("[Tool calls]: {:?}",
response_dict.tool_calls.iter()
info!(
"[Tool calls]: {:?}",
response_dict
.tool_calls
.iter()
.map(|tc| &tc.function)
.collect::<Vec<_>>()
);
@ -1108,7 +1184,10 @@ impl ArchFunctionHandler {
}
}
} else {
error!("Invalid tool calls in response: {}", response_dict.error_message);
error!(
"Invalid tool calls in response: {}",
response_dict.error_message
);
ResponseMessage {
role: Role::Assistant,
content: Some(String::new()),
@ -1243,7 +1322,6 @@ pub async fn function_calling_chat_handler(
req: Request<Incoming>,
llm_provider_url: String,
) -> std::result::Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
use hermesllm::apis::openai::ChatCompletionsRequest;
let whole_body = req.collect().await?.to_bytes();
@ -1255,10 +1333,13 @@ pub async fn function_calling_chat_handler(
let mut response = Response::new(full(
serde_json::json!({
"error": format!("Invalid request body: {}", e)
}).to_string()
})
.to_string(),
));
*response.status_mut() = StatusCode::BAD_REQUEST;
response.headers_mut().insert("Content-Type", "application/json".parse().unwrap());
response
.headers_mut()
.insert("Content-Type", "application/json".parse().unwrap());
return Ok(response);
}
};
@ -1271,24 +1352,31 @@ pub async fn function_calling_chat_handler(
// Parse as ChatCompletionsRequest
let chat_request: ChatCompletionsRequest = match serde_json::from_value(body_json) {
Ok(req) => {
info!("[request body]: {}", serde_json::to_string(&req).unwrap_or_default());
info!(
"[request body]: {}",
serde_json::to_string(&req).unwrap_or_default()
);
req
},
}
Err(e) => {
error!("Failed to parse request body: {}", e);
let mut response = Response::new(full(
serde_json::json!({
"error": format!("Invalid request body: {}", e)
}).to_string()
})
.to_string(),
));
*response.status_mut() = StatusCode::BAD_REQUEST;
response.headers_mut().insert("Content-Type", "application/json".parse().unwrap());
response
.headers_mut()
.insert("Content-Type", "application/json".parse().unwrap());
return Ok(response);
}
};
// Determine which handler to use based on metadata
let use_agent_orchestrator = chat_request.metadata
let use_agent_orchestrator = chat_request
.metadata
.as_ref()
.and_then(|m| m.get("use_agent_orchestrator"))
.and_then(|v| v.as_bool())
@ -1309,7 +1397,10 @@ pub async fn function_calling_chat_handler(
ARCH_FUNCTION_MODEL_NAME.to_string(),
llm_provider_url.clone(),
);
handler.function_handler.function_calling_chat(chat_request).await
handler
.function_handler
.function_calling_chat(chat_request)
.await
} else {
let handler = ArchFunctionHandler::new(
ARCH_FUNCTION_MODEL_NAME.to_string(),
@ -1328,7 +1419,9 @@ pub async fn function_calling_chat_handler(
let mut response = Response::new(full(response_json));
*response.status_mut() = StatusCode::OK;
response.headers_mut().insert("Content-Type", "application/json".parse().unwrap());
response
.headers_mut()
.insert("Content-Type", "application/json".parse().unwrap());
Ok(response)
}
@ -1341,13 +1434,14 @@ pub async fn function_calling_chat_handler(
let mut response = Response::new(full(error_response.to_string()));
*response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
response.headers_mut().insert("Content-Type", "application/json".parse().unwrap());
response
.headers_mut()
.insert("Content-Type", "application/json".parse().unwrap());
Ok(response)
}
}
}
// ============================================================================
// TESTS
// ============================================================================
@ -1370,10 +1464,13 @@ mod tests {
assert!(config.task_prompt.contains("</tools>\\n\\n"));
// Format prompt should contain literal escaped newlines and proper JSON examples
assert!(config.format_prompt.contains("\\n\\nBased on your analysis"));
assert!(config.format_prompt.contains(r#"{\"response\": \"Your response text here\"}"#));
assert!(config
.format_prompt
.contains("\\n\\nBased on your analysis"));
assert!(config
.format_prompt
.contains(r#"{\"response\": \"Your response text here\"}"#));
assert!(config.format_prompt.contains(r#"{\"tool_calls\": [{"#));
}
#[test]
@ -1384,7 +1481,11 @@ mod tests {
#[test]
fn test_fix_json_string_valid() {
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string());
let handler = ArchFunctionHandler::new(
"test-model".to_string(),
ArchFunctionConfig::default(),
"http://localhost:8000".to_string(),
);
let json_str = r#"{"name": "test", "value": 123}"#;
let result = handler.fix_json_string(json_str);
assert!(result.is_ok());
@ -1392,7 +1493,11 @@ mod tests {
#[test]
fn test_fix_json_string_missing_bracket() {
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string());
let handler = ArchFunctionHandler::new(
"test-model".to_string(),
ArchFunctionConfig::default(),
"http://localhost:8000".to_string(),
);
let json_str = r#"{"name": "test", "value": 123"#;
let result = handler.fix_json_string(json_str);
assert!(result.is_ok());
@ -1402,8 +1507,13 @@ mod tests {
#[test]
fn test_parse_model_response_with_tool_calls() {
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string());
let content = r#"{"tool_calls": [{"name": "get_weather", "arguments": {"location": "NYC"}}]}"#;
let handler = ArchFunctionHandler::new(
"test-model".to_string(),
ArchFunctionConfig::default(),
"http://localhost:8000".to_string(),
);
let content =
r#"{"tool_calls": [{"name": "get_weather", "arguments": {"location": "NYC"}}]}"#;
let result = handler.parse_model_response(content);
assert!(result.is_valid);
@ -1413,8 +1523,13 @@ mod tests {
#[test]
fn test_parse_model_response_with_clarification() {
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string());
let content = r#"{"required_functions": ["get_weather"], "clarification": "What location?"}"#;
let handler = ArchFunctionHandler::new(
"test-model".to_string(),
ArchFunctionConfig::default(),
"http://localhost:8000".to_string(),
);
let content =
r#"{"required_functions": ["get_weather"], "clarification": "What location?"}"#;
let result = handler.parse_model_response(content);
assert!(result.is_valid);
@ -1424,7 +1539,11 @@ mod tests {
#[test]
fn test_convert_data_type_int_to_float() {
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string());
let handler = ArchFunctionHandler::new(
"test-model".to_string(),
ArchFunctionConfig::default(),
"http://localhost:8000".to_string(),
);
let value = json!(42);
let result = handler.convert_data_type(&value, "float");
assert!(result.is_ok());
@ -1504,13 +1623,12 @@ pub fn check_threshold(
}
/// Checks if a parameter is required in the function description
pub fn is_parameter_required(
function_description: &Value,
parameter_name: &str,
) -> bool {
pub fn is_parameter_required(function_description: &Value, parameter_name: &str) -> bool {
if let Some(required) = function_description.get("required") {
if let Some(required_arr) = required.as_array() {
return required_arr.iter().any(|v| v.as_str() == Some(parameter_name));
return required_arr
.iter()
.any(|v| v.as_str() == Some(parameter_name));
}
}
false
@ -1559,12 +1677,7 @@ impl HallucinationState {
pub fn new(functions: &[Tool]) -> Self {
let function_properties: HashMap<String, Value> = functions
.iter()
.map(|tool| {
(
tool.function.name.clone(),
tool.function.parameters.clone(),
)
})
.map(|tool| (tool.function.name.clone(), tool.function.parameters.clone()))
.collect();
Self {
@ -1620,7 +1733,10 @@ impl HallucinationState {
// Function name extraction logic
if self.state.as_deref() == Some("function_name") {
if !FUNC_NAME_END_TOKEN.iter().any(|&t| self.tokens.last().map_or(false, |tok| tok == t)) {
if !FUNC_NAME_END_TOKEN
.iter()
.any(|&t| self.tokens.last().is_some_and(|tok| tok == t))
{
self.mask.push(MaskToken::FunctionName);
} else {
self.state = None;
@ -1629,34 +1745,51 @@ impl HallucinationState {
}
// Check for function name start
if FUNC_NAME_START_PATTERN.iter().any(|&p| content.ends_with(p)) {
if FUNC_NAME_START_PATTERN
.iter()
.any(|&p| content.ends_with(p))
{
self.state = Some("function_name".to_string());
}
// Parameter name extraction logic
if self.state.as_deref() == Some("parameter_name")
&& !PARAMETER_NAME_END_TOKENS.iter().any(|&t| content.ends_with(t)) {
&& !PARAMETER_NAME_END_TOKENS
.iter()
.any(|&t| content.ends_with(t))
{
self.mask.push(MaskToken::ParameterName);
} else if self.state.as_deref() == Some("parameter_name")
&& PARAMETER_NAME_END_TOKENS.iter().any(|&t| content.ends_with(t)) {
&& PARAMETER_NAME_END_TOKENS
.iter()
.any(|&t| content.ends_with(t))
{
self.state = None;
self.parameter_name_done = true;
self.get_parameter_name();
} else if self.parameter_name_done
&& !self.open_bracket
&& PARAMETER_NAME_START_PATTERN.iter().any(|&p| content.ends_with(p)) {
&& PARAMETER_NAME_START_PATTERN
.iter()
.any(|&p| content.ends_with(p))
{
self.state = Some("parameter_name".to_string());
}
// First parameter value start
if FIRST_PARAM_NAME_START_PATTERN.iter().any(|&p| content.ends_with(p)) {
if FIRST_PARAM_NAME_START_PATTERN
.iter()
.any(|&p| content.ends_with(p))
{
self.state = Some("parameter_name".to_string());
}
// Parameter value extraction logic
if self.state.as_deref() == Some("parameter_value")
&& !PARAMETER_VALUE_END_TOKEN.iter().any(|&t| content.ends_with(t)) {
&& !PARAMETER_VALUE_END_TOKEN
.iter()
.any(|&t| content.ends_with(t))
{
// Check for brackets
if let Some(last_token) = self.tokens.last() {
let open_brackets: Vec<char> = last_token
@ -1694,8 +1827,11 @@ impl HallucinationState {
&& self.mask[self.mask.len() - 2] != MaskToken::ParameterValue
&& !self.parameter_name.is_empty()
{
let last_param = self.parameter_name[self.parameter_name.len() - 1].clone();
if let Some(func_props) = self.function_properties.get(&self.function_name) {
let last_param =
self.parameter_name[self.parameter_name.len() - 1].clone();
if let Some(func_props) =
self.function_properties.get(&self.function_name)
{
if is_parameter_required(func_props, &last_param)
&& !is_parameter_property(func_props, &last_param, "enum")
&& !self.check_parameter_name.contains_key(&last_param)
@ -1718,10 +1854,16 @@ impl HallucinationState {
}
} else if self.state.as_deref() == Some("parameter_value")
&& !self.open_bracket
&& PARAMETER_VALUE_END_TOKEN.iter().any(|&t| content.ends_with(t)) {
&& PARAMETER_VALUE_END_TOKEN
.iter()
.any(|&t| content.ends_with(t))
{
self.state = None;
} else if self.parameter_name_done
&& PARAMETER_VALUE_START_PATTERN.iter().any(|&p| content.ends_with(p)) {
&& PARAMETER_VALUE_START_PATTERN
.iter()
.any(|&p| content.ends_with(p))
{
self.state = Some("parameter_value".to_string());
}
@ -1848,18 +1990,18 @@ mod hallucination_tests {
let handler = ArchFunctionHandler::new(
"test-model".to_string(),
ArchFunctionConfig::default(),
"http://localhost:8000".to_string()
"http://localhost:8000".to_string(),
);
// Test integer types
assert!(handler.check_value_type(&json!(42), "integer"));
assert!(handler.check_value_type(&json!(42), "int"));
assert!(!handler.check_value_type(&json!(3.14), "integer"));
assert!(!handler.check_value_type(&json!(3.15), "integer"));
// Test number types (accepts both int and float)
assert!(handler.check_value_type(&json!(3.14), "number"));
assert!(handler.check_value_type(&json!(3.15), "number"));
assert!(handler.check_value_type(&json!(42), "number"));
assert!(handler.check_value_type(&json!(3.14), "float"));
assert!(handler.check_value_type(&json!(3.15), "float"));
// Test boolean
assert!(handler.check_value_type(&json!(true), "boolean"));
@ -1890,12 +2032,16 @@ mod hallucination_tests {
let handler = ArchFunctionHandler::new(
"test-model".to_string(),
ArchFunctionConfig::default(),
"http://localhost:8000".to_string()
"http://localhost:8000".to_string(),
);
// Test valid type - no conversion needed
assert!(handler.validate_or_convert_parameter(&json!(42), "integer").unwrap());
assert!(handler.validate_or_convert_parameter(&json!("hello"), "string").unwrap());
assert!(handler
.validate_or_convert_parameter(&json!(42), "integer")
.unwrap());
assert!(handler
.validate_or_convert_parameter(&json!("hello"), "string")
.unwrap());
// Test integer to float conversion (convert_data_type supports this)
let result = handler.validate_or_convert_parameter(&json!(42), "float");
@ -1910,8 +2056,12 @@ mod hallucination_tests {
assert!(!result.unwrap());
// Test number accepting both int and float
assert!(handler.validate_or_convert_parameter(&json!(42), "number").unwrap());
assert!(handler.validate_or_convert_parameter(&json!(3.14), "number").unwrap());
assert!(handler
.validate_or_convert_parameter(&json!(42), "number")
.unwrap());
assert!(handler
.validate_or_convert_parameter(&json!(3.15), "number")
.unwrap());
}
#[test]

View file

@ -14,7 +14,7 @@ use crate::router::plano_orchestrator::OrchestratorService;
/// 2. PipelineProcessor - executes the agent pipeline
/// 3. ResponseHandler - handles response streaming
#[cfg(test)]
mod integration_tests {
mod tests {
use super::*;
use common::configuration::{Agent, AgentFilterChain, Listener};
@ -62,7 +62,10 @@ mod integration_tests {
let agent_pipeline = AgentFilterChain {
id: "terminal-agent".to_string(),
filter_chain: Some(vec!["filter-agent".to_string(), "terminal-agent".to_string()]),
filter_chain: Some(vec![
"filter-agent".to_string(),
"terminal-agent".to_string(),
]),
description: Some("Test pipeline".to_string()),
default: Some(true),
};

View file

@ -2,48 +2,48 @@ use serde::{Deserialize, Serialize};
use std::collections::HashMap;
pub const JSON_RPC_VERSION: &str = "2.0";
pub const TOOL_CALL_METHOD : &str = "tools/call";
pub const TOOL_CALL_METHOD: &str = "tools/call";
pub const MCP_INITIALIZE: &str = "initialize";
pub const MCP_INITIALIZE_NOTIFICATION: &str = "notifications/initialized";
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum JsonRpcId {
String(String),
Number(u64),
String(String),
Number(u64),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcRequest {
pub jsonrpc: String,
pub id: JsonRpcId,
pub method: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<HashMap<String, serde_json::Value>>,
pub jsonrpc: String,
pub id: JsonRpcId,
pub method: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcNotification {
pub jsonrpc: String,
pub method: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<HashMap<String, serde_json::Value>>,
pub jsonrpc: String,
pub method: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcError {
pub code: i32,
pub message: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub data: Option<serde_json::Value>,
pub code: i32,
pub message: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub data: Option<serde_json::Value>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcResponse {
pub jsonrpc: String,
pub id: JsonRpcId,
#[serde(skip_serializing_if = "Option::is_none")]
pub result: Option<HashMap<String, serde_json::Value>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<JsonRpcError>,
pub jsonrpc: String,
pub id: JsonRpcId,
#[serde(skip_serializing_if = "Option::is_none")]
pub result: Option<HashMap<String, serde_json::Value>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<JsonRpcError>,
}

View file

@ -1,6 +1,8 @@
use bytes::Bytes;
use common::configuration::{LlmProvider, ModelAlias};
use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER};
use common::consts::{
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
};
use common::traces::TraceCollector;
use hermesllm::apis::openai_responses::InputParam;
use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
@ -14,13 +16,14 @@ use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
use crate::router::llm_router::RouterService;
use crate::handlers::utils::{create_streaming_response, ObservableStreamProcessor, truncate_message};
use crate::handlers::router_chat::router_chat_get_upstream_model;
use crate::handlers::utils::{
create_streaming_response, truncate_message, ObservableStreamProcessor,
};
use crate::router::llm_router::RouterService;
use crate::state::response_state_processor::ResponsesStateProcessor;
use crate::state::{
StateStorage, StateStorageError,
extract_input_items, retrieve_and_combine_input
extract_input_items, retrieve_and_combine_input, StateStorage, StateStorageError,
};
use crate::tracing::operation_component;
@ -39,7 +42,6 @@ pub async fn llm_chat(
trace_collector: Arc<TraceCollector>,
state_storage: Option<Arc<dyn StateStorage>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let request_path = request.uri().path().to_string();
let request_headers = request.headers().clone();
let request_id = request_headers
@ -74,8 +76,14 @@ pub async fn llm_chat(
)) {
Ok(request) => request,
Err(err) => {
warn!("[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request as ProviderRequestType: {}", request_id, err);
let err_msg = format!("[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request: {}", request_id, err);
warn!(
"[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request as ProviderRequestType: {}",
request_id, err
);
let err_msg = format!(
"[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request: {}",
request_id, err
);
let mut bad_request = Response::new(full(err_msg));
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
return Ok(bad_request);
@ -85,7 +93,10 @@ pub async fn llm_chat(
// === v1/responses state management: Extract input items early ===
let mut original_input_items = Vec::new();
let client_api = SupportedAPIsFromClient::from_endpoint(request_path.as_str());
let is_responses_api_client = matches!(client_api, Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_)));
let is_responses_api_client = matches!(
client_api,
Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_))
);
// Model alias resolution: update model field in client_request immediately
// This ensures all downstream objects use the resolved model
@ -96,20 +107,28 @@ pub async fn llm_chat(
// Extract tool names and user message preview for span attributes
let tool_names = client_request.get_tool_names();
let user_message_preview = client_request.get_recent_user_message()
let user_message_preview = client_request
.get_recent_user_message()
.map(|msg| truncate_message(&msg, 50));
client_request.set_model(resolved_model.clone());
if client_request.remove_metadata_key("archgw_preference_config") {
debug!("[PLANO_REQ_ID:{}] Removed archgw_preference_config from metadata", request_id);
debug!(
"[PLANO_REQ_ID:{}] Removed archgw_preference_config from metadata",
request_id
);
}
// === v1/responses state management: Determine upstream API and combine input if needed ===
// Do this BEFORE routing since routing consumes the request
// Only process state if state_storage is configured
let mut should_manage_state = false;
if is_responses_api_client && state_storage.is_some() {
if let ProviderRequestType::ResponsesAPIRequest(ref mut responses_req) = client_request {
if is_responses_api_client {
if let (
ProviderRequestType::ResponsesAPIRequest(ref mut responses_req),
Some(ref state_store),
) = (&mut client_request, &state_storage)
{
// Extract original input once
original_input_items = extract_input_items(&responses_req.input);
@ -120,18 +139,22 @@ pub async fn llm_chat(
&request_path,
&resolved_model,
is_streaming_request,
).await;
)
.await;
let upstream_api = SupportedUpstreamAPIs::from_endpoint(&upstream_path);
// Only manage state if upstream is NOT OpenAIResponsesAPI (needs translation)
should_manage_state = !matches!(upstream_api, Some(SupportedUpstreamAPIs::OpenAIResponsesAPI(_)));
should_manage_state = !matches!(
upstream_api,
Some(SupportedUpstreamAPIs::OpenAIResponsesAPI(_))
);
if should_manage_state {
// Retrieve and combine conversation history if previous_response_id exists
if let Some(ref prev_resp_id) = responses_req.previous_response_id {
match retrieve_and_combine_input(
state_storage.as_ref().unwrap().clone(),
state_store.clone(),
prev_resp_id,
original_input_items, // Pass ownership instead of cloning
)
@ -166,7 +189,10 @@ pub async fn llm_chat(
}
}
} else {
debug!("[PLANO_REQ_ID:{}] | BRIGHT_STAFF | Upstream supports ResponsesAPI natively.", request_id);
debug!(
"[PLANO_REQ_ID:{}] | BRIGHT_STAFF | Upstream supports ResponsesAPI natively.",
request_id
);
}
}
}
@ -177,7 +203,7 @@ pub async fn llm_chat(
// Determine routing using the dedicated router_chat module
let routing_result = match router_chat_get_upstream_model(
router_service,
client_request, // Pass the original request - router_chat will convert it
client_request, // Pass the original request - router_chat will convert it
&request_headers,
trace_collector.clone(),
&traceparent,
@ -257,7 +283,8 @@ pub async fn llm_chat(
user_message_preview,
temperature,
&llm_providers,
).await;
)
.await;
// Create base processor for metrics and tracing
let base_processor = ObservableStreamProcessor::new(
@ -269,7 +296,11 @@ pub async fn llm_chat(
// === v1/responses state management: Wrap with ResponsesStateProcessor ===
// Only wrap if we need to manage state (client is ResponsesAPI AND upstream is NOT ResponsesAPI AND state_storage is configured)
let streaming_response = if should_manage_state && !original_input_items.is_empty() && state_storage.is_some() {
let streaming_response = if let (true, false, Some(state_store)) = (
should_manage_state,
original_input_items.is_empty(),
state_storage,
) {
// Extract Content-Encoding header to handle decompression for state parsing
let content_encoding = response_headers
.get("content-encoding")
@ -279,7 +310,7 @@ pub async fn llm_chat(
// Wrap with state management processor to store state after response completes
let state_processor = ResponsesStateProcessor::new(
base_processor,
state_storage.unwrap(),
state_store,
original_input_items,
resolved_model.clone(),
model_name.clone(),
@ -324,6 +355,7 @@ fn resolve_model_alias(
}
/// Builds the LLM span with all required and optional attributes.
#[allow(clippy::too_many_arguments)]
async fn build_llm_span(
traceparent: &str,
request_path: &str,
@ -337,8 +369,8 @@ async fn build_llm_span(
temperature: Option<f32>,
llm_providers: &Arc<RwLock<Vec<LlmProvider>>>,
) -> common::traces::Span {
use common::traces::{SpanBuilder, SpanKind, parse_traceparent};
use crate::tracing::{http, llm, OperationNameBuilder};
use common::traces::{parse_traceparent, SpanBuilder, SpanKind};
// Calculate the upstream path based on provider configuration
let upstream_path = get_upstream_path(
@ -347,13 +379,14 @@ async fn build_llm_span(
request_path,
resolved_model,
is_streaming,
).await;
)
.await;
// Build operation name showing path transformation if different
let operation_name = if request_path != upstream_path {
OperationNameBuilder::new()
.with_method("POST")
.with_path(&format!("{} >> {}", request_path, upstream_path))
.with_path(format!("{} >> {}", request_path, upstream_path))
.with_target(resolved_model)
.build()
} else {
@ -388,7 +421,8 @@ async fn build_llm_span(
}
if let Some(tools) = tool_names {
let formatted_tools = tools.iter()
let formatted_tools = tools
.iter()
.map(|name| format!("{}(...)", name))
.collect::<Vec<_>>()
.join("\n");
@ -436,8 +470,7 @@ async fn get_provider_info(
// First, try to find by model name or provider name
let provider = providers_lock.iter().find(|p| {
p.model.as_ref().map(|m| m == model_name).unwrap_or(false)
|| p.name == model_name
p.model.as_ref().map(|m| m == model_name).unwrap_or(false) || p.name == model_name
});
if let Some(provider) = provider {
@ -446,9 +479,7 @@ async fn get_provider_info(
return (provider_id, prefix);
}
let default_provider = providers_lock.iter().find(|p| {
p.default.unwrap_or(false)
});
let default_provider = providers_lock.iter().find(|p| p.default.unwrap_or(false));
if let Some(provider) = default_provider {
let provider_id = provider.provider_interface.to_provider_id();

View file

@ -1,13 +1,13 @@
pub mod agent_chat_completions;
pub mod agent_selector;
pub mod llm;
pub mod router_chat;
pub mod models;
pub mod function_calling;
pub mod jsonrpc;
pub mod llm;
pub mod models;
pub mod pipeline_processor;
pub mod response_handler;
pub mod router_chat;
pub mod utils;
pub mod jsonrpc;
#[cfg(test)]
mod integration_tests;

View file

@ -82,6 +82,7 @@ impl PipelineProcessor {
}
/// Record a span for filter execution
#[allow(clippy::too_many_arguments)]
fn record_filter_span(
&self,
collector: &std::sync::Arc<common::traces::TraceCollector>,
@ -132,6 +133,7 @@ impl PipelineProcessor {
}
/// Record a span for MCP protocol interactions
#[allow(clippy::too_many_arguments)]
fn record_agent_filter_span(
&self,
collector: &std::sync::Arc<common::traces::TraceCollector>,
@ -156,12 +158,12 @@ impl PipelineProcessor {
.build();
let mut span_builder = SpanBuilder::new(&operation_name)
.with_span_id(span_id.unwrap_or_else(|| generate_random_span_id()))
.with_span_id(span_id.unwrap_or_else(generate_random_span_id))
.with_kind(SpanKind::Client)
.with_start_time(start_time)
.with_end_time(end_time)
.with_attribute(http::METHOD, "POST")
.with_attribute(http::TARGET, &format!("/mcp ({})", operation.to_string()))
.with_attribute(http::TARGET, format!("/mcp ({})", operation))
.with_attribute("mcp.operation", operation.to_string())
.with_attribute("mcp.agent_id", agent_id.to_string())
.with_attribute(
@ -188,6 +190,7 @@ impl PipelineProcessor {
}
/// Process the filter chain of agents (all except the terminal agent)
#[allow(clippy::too_many_arguments)]
pub async fn process_filter_chain(
&mut self,
chat_history: &[Message],
@ -1023,7 +1026,7 @@ mod tests {
}
});
let sse_body = format!("event: message\ndata: {}\n\n", rpc_body.to_string());
let sse_body = format!("event: message\ndata: {}\n\n", rpc_body);
let mut server = Server::new_async().await;
let _m = server
@ -1061,10 +1064,10 @@ mod tests {
.await;
match result {
Err(PipelineError::ClientError { status, body, .. }) => {
assert_eq!(status, 400);
assert_eq!(body, "bad tool call");
}
Err(PipelineError::ClientError { status, body, .. }) => {
assert_eq!(status, 400);
assert_eq!(body, "bad tool call");
}
_ => panic!("Expected client error when isError flag is set"),
}
}

View file

@ -133,9 +133,7 @@ impl ResponseHandler {
let response_headers = llm_response.headers();
let is_sse_streaming = response_headers
.get(hyper::header::CONTENT_TYPE)
.map_or(false, |v| {
v.to_str().unwrap_or("").contains("text/event-stream")
});
.is_some_and(|v| v.to_str().unwrap_or("").contains("text/event-stream"));
let response_bytes = llm_response
.bytes()
@ -164,7 +162,7 @@ impl ResponseHandler {
match transformed_event.provider_response() {
Ok(provider_response) => {
if let Some(content) = provider_response.content_delta() {
accumulated_text.push_str(&content);
accumulated_text.push_str(content);
} else {
info!("No content delta in provider response");
}
@ -174,7 +172,7 @@ impl ResponseHandler {
}
}
}
return Ok(accumulated_text);
Ok(accumulated_text)
} else {
// If not SSE, treat as regular text response
let response_text = String::from_utf8(response_bytes.to_vec()).map_err(|e| {

View file

@ -1,6 +1,6 @@
use common::configuration::ModelUsagePreference;
use common::consts::{REQUEST_ID_HEADER};
use common::traces::{TraceCollector, SpanKind, SpanBuilder, parse_traceparent};
use common::consts::REQUEST_ID_HEADER;
use common::traces::{parse_traceparent, SpanBuilder, SpanKind, TraceCollector};
use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
use hermesllm::{ProviderRequest, ProviderRequestType};
use hyper::StatusCode;
@ -9,10 +9,10 @@ use std::sync::Arc;
use tracing::{debug, info, warn};
use crate::router::llm_router::RouterService;
use crate::tracing::{OperationNameBuilder, operation_component, http, routing};
use crate::tracing::{http, operation_component, routing, OperationNameBuilder};
pub struct RoutingResult {
pub model_name: String
pub model_name: String,
}
pub struct RoutingError {
@ -24,7 +24,7 @@ impl RoutingError {
pub fn internal_error(message: String) -> Self {
Self {
message,
status_code: StatusCode::INTERNAL_SERVER_ERROR
status_code: StatusCode::INTERNAL_SERVER_ERROR,
}
}
}
@ -52,9 +52,7 @@ pub async fn router_chat_get_upstream_model(
// Convert to ChatCompletionsRequest for routing (regardless of input type)
let chat_request = match ProviderRequestType::try_from((
client_request,
&SupportedUpstreamAPIs::OpenAIChatCompletions(
hermesllm::apis::OpenAIApi::ChatCompletions,
),
&SupportedUpstreamAPIs::OpenAIChatCompletions(hermesllm::apis::OpenAIApi::ChatCompletions),
)) {
Ok(ProviderRequestType::ChatCompletionsRequest(req)) => req,
Ok(
@ -69,7 +67,10 @@ pub async fn router_chat_get_upstream_model(
));
}
Err(err) => {
warn!("Failed to convert request to ChatCompletionsRequest: {}", err);
warn!(
"Failed to convert request to ChatCompletionsRequest: {}",
err
);
return Err(RoutingError::internal_error(format!(
"Failed to convert request: {}",
err
@ -151,9 +152,7 @@ pub async fn router_chat_get_upstream_model(
)
.await;
Ok(RoutingResult {
model_name
})
Ok(RoutingResult { model_name })
}
None => {
// No route determined, use default model from request
@ -176,7 +175,7 @@ pub async fn router_chat_get_upstream_model(
.await;
Ok(RoutingResult {
model_name: default_model
model_name: default_model,
})
}
},
@ -194,9 +193,10 @@ pub async fn router_chat_get_upstream_model(
)
.await;
Err(RoutingError::internal_error(
format!("Failed to determine route: {}", err)
))
Err(RoutingError::internal_error(format!(
"Failed to determine route: {}",
err
)))
}
}
}
@ -230,7 +230,10 @@ async fn record_routing_span(
.with_end_time(std::time::SystemTime::now())
.with_attribute(http::METHOD, "POST")
.with_attribute(http::TARGET, routing_api_path.to_string())
.with_attribute(routing::ROUTE_DETERMINATION_MS, start_time.elapsed().as_millis().to_string());
.with_attribute(
routing::ROUTE_DETERMINATION_MS,
start_time.elapsed().as_millis().to_string(),
);
// Only set parent span ID if it exists (not a root span)
if let Some(parent) = parent_span_id {

View file

@ -1,5 +1,5 @@
use bytes::Bytes;
use common::traces::{Span, Attribute, AttributeValue, TraceCollector, Event};
use common::traces::{Attribute, AttributeValue, Event, Span, TraceCollector};
use http_body_util::combinators::BoxBody;
use http_body_util::StreamBody;
use hyper::body::Frame;
@ -11,7 +11,7 @@ use tokio_stream::StreamExt;
use tracing::warn;
// Import tracing constants
use crate::tracing::{llm, error};
use crate::tracing::{error, llm};
/// Trait for processing streaming chunks
/// Implementors can inject custom logic during streaming (e.g., hallucination detection, logging)
@ -97,7 +97,6 @@ impl StreamProcessor for ObservableStreamProcessor {
},
});
self.span.attributes.push(Attribute {
key: llm::DURATION_MS.to_string(),
value: AttributeValue {
@ -119,11 +118,9 @@ impl StreamProcessor for ObservableStreamProcessor {
if let Ok(start_time_nanos) = self.span.start_time_unix_nano.parse::<u128>() {
// Convert ttft from milliseconds to nanoseconds and add to start time
let event_timestamp = start_time_nanos + (ttft * 1_000_000);
let mut event = Event::new(llm::TIME_TO_FIRST_TOKEN_MS.to_string(), event_timestamp);
event.add_attribute(
llm::TIME_TO_FIRST_TOKEN_MS.to_string(),
ttft.to_string(),
);
let mut event =
Event::new(llm::TIME_TO_FIRST_TOKEN_MS.to_string(), event_timestamp);
event.add_attribute(llm::TIME_TO_FIRST_TOKEN_MS.to_string(), ttft.to_string());
// Initialize events vector if needed
if self.span.events.is_none() {
@ -137,7 +134,8 @@ impl StreamProcessor for ObservableStreamProcessor {
}
// Record the finalized span
self.collector.record_span(&self.service_name, self.span.clone());
self.collector
.record_span(&self.service_name, self.span.clone());
}
fn on_error(&mut self, error_msg: &str) {
@ -173,7 +171,8 @@ impl StreamProcessor for ObservableStreamProcessor {
});
// Record the error span
self.collector.record_span(&self.service_name, self.span.clone());
self.collector
.record_span(&self.service_name, self.span.clone());
}
}