cargo clippy (#660)

This commit is contained in:
Adil Hafeez 2025-12-25 21:08:37 -08:00 committed by GitHub
parent c75e7606f9
commit ca95ffb63d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
62 changed files with 1864 additions and 1187 deletions

View file

@ -3,7 +3,7 @@ use std::time::{Instant, SystemTime};
use bytes::Bytes;
use common::consts::TRACE_PARENT_HEADER;
use common::traces::{SpanBuilder, SpanKind, parse_traceparent, generate_random_span_id};
use common::traces::{generate_random_span_id, parse_traceparent, SpanBuilder, SpanKind};
use hermesllm::apis::OpenAIMessage;
use hermesllm::clients::SupportedAPIsFromClient;
use hermesllm::providers::request::ProviderRequest;
@ -18,7 +18,7 @@ use super::agent_selector::{AgentSelectionError, AgentSelector};
use super::pipeline_processor::{PipelineError, PipelineProcessor};
use super::response_handler::ResponseHandler;
use crate::router::plano_orchestrator::OrchestratorService;
use crate::tracing::{OperationNameBuilder, operation_component, http};
use crate::tracing::{http, operation_component, OperationNameBuilder};
/// Main errors for agent chat completions
#[derive(Debug, thiserror::Error)]
@ -61,7 +61,6 @@ pub async fn agent_chat(
body,
}) = &err
{
warn!(
"Client error from agent '{}' (HTTP {}): {}",
agent, status, body
@ -77,8 +76,8 @@ pub async fn agent_chat(
let json_string = error_json.to_string();
let mut response = Response::new(ResponseHandler::create_full_body(json_string));
*response.status_mut() = hyper::StatusCode::from_u16(*status)
.unwrap_or(hyper::StatusCode::BAD_REQUEST);
*response.status_mut() =
hyper::StatusCode::from_u16(*status).unwrap_or(hyper::StatusCode::BAD_REQUEST);
response.headers_mut().insert(
hyper::header::CONTENT_TYPE,
"application/json".parse().unwrap(),
@ -234,8 +233,18 @@ async fn handle_agent_chat(
.with_attribute(http::TARGET, "/agents/select")
.with_attribute("selection.listener", listener.name.clone())
.with_attribute("selection.agent_count", selected_agents.len().to_string())
.with_attribute("selection.agents", selected_agents.iter().map(|a| a.id.as_str()).collect::<Vec<_>>().join(","))
.with_attribute("duration_ms", format!("{:.2}", selection_elapsed.as_secs_f64() * 1000.0));
.with_attribute(
"selection.agents",
selected_agents
.iter()
.map(|a| a.id.as_str())
.collect::<Vec<_>>()
.join(","),
)
.with_attribute(
"duration_ms",
format!("{:.2}", selection_elapsed.as_secs_f64() * 1000.0),
);
if !trace_id.is_empty() {
selection_span_builder = selection_span_builder.with_trace_id(trace_id.clone());
@ -318,8 +327,14 @@ async fn handle_agent_chat(
.with_attribute(http::METHOD, "POST")
.with_attribute(http::TARGET, full_path)
.with_attribute("agent.name", agent_name.clone())
.with_attribute("agent.sequence", format!("{}/{}", agent_index + 1, agent_count))
.with_attribute("duration_ms", format!("{:.2}", agent_elapsed.as_secs_f64() * 1000.0));
.with_attribute(
"agent.sequence",
format!("{}/{}", agent_index + 1, agent_count),
)
.with_attribute(
"duration_ms",
format!("{:.2}", agent_elapsed.as_secs_f64() * 1000.0),
);
if !trace_id.is_empty() {
span_builder = span_builder.with_trace_id(trace_id.clone());
@ -333,7 +348,10 @@ async fn handle_agent_chat(
// If this is the last agent, return the streaming response
if is_last_agent {
info!("Completed agent chain, returning response from last agent: {}", agent_name);
info!(
"Completed agent chain, returning response from last agent: {}",
agent_name
);
return response_handler
.create_streaming_response(llm_response)
.await
@ -341,7 +359,10 @@ async fn handle_agent_chat(
}
// For intermediate agents, collect the full response and pass to next agent
debug!("Collecting response from intermediate agent: {}", agent_name);
debug!(
"Collecting response from intermediate agent: {}",
agent_name
);
let response_text = response_handler.collect_full_response(llm_response).await?;
info!(
@ -364,7 +385,6 @@ async fn handle_agent_chat(
});
current_messages.push(last_message);
}
// This should never be reached since we return in the last agent iteration

View file

@ -2,7 +2,7 @@ use std::collections::HashMap;
use std::sync::Arc;
use common::configuration::{
Agent, AgentFilterChain, Listener, AgentUsagePreference, OrchestrationPreference,
Agent, AgentFilterChain, AgentUsagePreference, Listener, OrchestrationPreference,
};
use hermesllm::apis::openai::Message;
use tracing::{debug, warn};

View file

@ -1,20 +1,18 @@
use bytes::Bytes;
use eventsource_stream::Eventsource;
use futures::StreamExt;
use hermesllm::apis::openai::{
ChatCompletionsRequest, ChatCompletionsResponse, Choice, FinishReason, FunctionCall, Message,
MessageContent, ResponseMessage, Role, Tool, ToolCall, Usage,
};
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use hyper::body::Incoming;
use hyper::{Request, Response, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
use thiserror::Error;
use tracing::{info, error};
use futures::StreamExt;
use bytes::Bytes;
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use hyper::body::Incoming;
use hyper::{Request, Response, StatusCode};
use eventsource_stream::Eventsource;
use tracing::{error, info};
// ============================================================================
// CONSTANTS FOR HALLUCINATION DETECTION
@ -273,17 +271,14 @@ impl ArchFunctionHandler {
let mut stack: Vec<char> = Vec::new();
let mut fixed_str = String::new();
let matching_bracket: HashMap<char, char> =
[(')', '('), ('}', '{'), (']', '[')]
.iter()
.cloned()
.collect();
let opening_bracket: HashMap<char, char> = matching_bracket
let matching_bracket: HashMap<char, char> = [(')', '('), ('}', '{'), (']', '[')]
.iter()
.map(|(k, v)| (*v, *k))
.cloned()
.collect();
let opening_bracket: HashMap<char, char> =
matching_bracket.iter().map(|(k, v)| (*v, *k)).collect();
for ch in json_str.chars() {
if ch == '{' || ch == '[' || ch == '(' {
stack.push(ch);
@ -332,12 +327,18 @@ impl ArchFunctionHandler {
// Remove markdown code blocks
let mut content = content.trim().to_string();
if content.starts_with("```") && content.ends_with("```") {
content = content.trim_start_matches("```").trim_end_matches("```").to_string();
content = content
.trim_start_matches("```")
.trim_end_matches("```")
.to_string();
if content.starts_with("json") {
content = content.trim_start_matches("json").to_string();
}
// Trim again after removing code blocks to eliminate internal whitespace
content = content.trim_start_matches(r"\n").trim_end_matches(r"\n").to_string();
content = content
.trim_start_matches(r"\n")
.trim_end_matches(r"\n")
.to_string();
content = content.trim().to_string();
// Unescape the quotes: \" -> "
// The model sometimes returns escaped JSON inside markdown blocks
@ -453,12 +454,12 @@ impl ArchFunctionHandler {
/// Helper method to check if a value matches the expected type
fn check_value_type(&self, value: &Value, target_type: &str) -> bool {
match target_type {
"int" | "integer" => value.is_i64() || value.is_u64(),
"int" | "integer" => value.is_i64() || value.is_u64(),
"float" | "number" => value.is_f64() || value.is_i64() || value.is_u64(),
"bool" | "boolean" => value.is_boolean(),
"str" | "string" => value.is_string(),
"list" | "array" => value.is_array(),
"dict" | "object" => value.is_object(),
"bool" | "boolean" => value.is_boolean(),
"str" | "string" => value.is_string(),
"list" | "array" => value.is_array(),
"dict" | "object" => value.is_object(),
_ => true,
}
}
@ -505,15 +506,19 @@ impl ArchFunctionHandler {
let func_name = &tool_call.function.name;
// Parse arguments as JSON
let func_args: HashMap<String, Value> = match serde_json::from_str(&tool_call.function.arguments) {
Ok(args) => args,
Err(e) => {
verification.is_valid = false;
verification.invalid_tool_call = Some(tool_call.clone());
verification.error_message = format!("Failed to parse arguments for function '{}': {}", func_name, e);
break;
}
};
let func_args: HashMap<String, Value> =
match serde_json::from_str(&tool_call.function.arguments) {
Ok(args) => args,
Err(e) => {
verification.is_valid = false;
verification.invalid_tool_call = Some(tool_call.clone());
verification.error_message = format!(
"Failed to parse arguments for function '{}': {}",
func_name, e
);
break;
}
};
// Check if function is available
if let Some(function_params) = functions.get(func_name) {
@ -541,14 +546,23 @@ impl ArchFunctionHandler {
if let Some(properties_obj) = properties.as_object() {
for (param_name, param_value) in &func_args {
if let Some(param_schema) = properties_obj.get(param_name) {
if let Some(target_type) = param_schema.get("type").and_then(|v| v.as_str()) {
if self.config.support_data_types.contains(&target_type.to_string()) {
if let Some(target_type) =
param_schema.get("type").and_then(|v| v.as_str())
{
if self
.config
.support_data_types
.contains(&target_type.to_string())
{
// Validate data type using helper method
match self.validate_or_convert_parameter(param_value, target_type) {
match self
.validate_or_convert_parameter(param_value, target_type)
{
Ok(is_valid) => {
if !is_valid {
verification.is_valid = false;
verification.invalid_tool_call = Some(tool_call.clone());
verification.invalid_tool_call =
Some(tool_call.clone());
verification.error_message = format!(
"Parameter `{}` is expected to have the data type `{}`, got incompatible type.",
param_name, target_type
@ -558,7 +572,8 @@ impl ArchFunctionHandler {
}
Err(_) => {
verification.is_valid = false;
verification.invalid_tool_call = Some(tool_call.clone());
verification.invalid_tool_call =
Some(tool_call.clone());
verification.error_message = format!(
"Parameter `{}` is expected to have the data type `{}`, got incompatible type.",
param_name, target_type
@ -569,7 +584,10 @@ impl ArchFunctionHandler {
} else {
verification.is_valid = false;
verification.invalid_tool_call = Some(tool_call.clone());
verification.error_message = format!("Data type `{}` is not supported.", target_type);
verification.error_message = format!(
"Data type `{}` is not supported.",
target_type
);
break;
}
}
@ -598,11 +616,8 @@ impl ArchFunctionHandler {
/// Formats the system prompt with tools
pub fn format_system_prompt(&self, tools: &[Tool]) -> Result<String> {
let tools_str = self.convert_tools(tools)?;
let system_prompt = self
.config
.task_prompt
.replace("{tools}", &tools_str)
+ &self.config.format_prompt;
let system_prompt =
self.config.task_prompt.replace("{tools}", &tools_str) + &self.config.format_prompt;
Ok(system_prompt)
}
@ -665,15 +680,22 @@ impl ArchFunctionHandler {
// Strip markdown code blocks
if tool_call_msg.starts_with("```") && tool_call_msg.ends_with("```") {
tool_call_msg = tool_call_msg.trim_start_matches("```").trim_end_matches("```").trim().to_string();
tool_call_msg = tool_call_msg
.trim_start_matches("```")
.trim_end_matches("```")
.trim()
.to_string();
if tool_call_msg.starts_with("json") {
tool_call_msg = tool_call_msg.trim_start_matches("json").trim().to_string();
tool_call_msg =
tool_call_msg.trim_start_matches("json").trim().to_string();
}
}
// Extract function name
if let Ok(parsed) = serde_json::from_str::<Value>(&tool_call_msg) {
if let Some(tool_calls_arr) = parsed.get("tool_calls").and_then(|v| v.as_array()) {
if let Some(tool_calls_arr) =
parsed.get("tool_calls").and_then(|v| v.as_array())
{
if let Some(first_tool_call) = tool_calls_arr.first() {
let func_name = first_tool_call
.get("name")
@ -685,8 +707,10 @@ impl ArchFunctionHandler {
"result": content,
});
content = format!("<tool_response>\n{}\n</tool_response>",
serde_json::to_string(&tool_response)?);
content = format!(
"<tool_response>\n{}\n</tool_response>",
serde_json::to_string(&tool_response)?
);
}
}
}
@ -717,7 +741,7 @@ impl ArchFunctionHandler {
if let Some(instruction) = extra_instruction {
if let Some(last) = processed_messages.last_mut() {
if let MessageContent::Text(content) = &mut last.content {
content.push_str("\n");
content.push('\n');
content.push_str(instruction);
}
}
@ -750,13 +774,11 @@ impl ArchFunctionHandler {
for i in (conversation_idx..messages.len()).rev() {
if let MessageContent::Text(content) = &messages[i].content {
num_tokens += content.len() / 4;
if num_tokens >= max_tokens {
if messages[i].role == Role::User {
// Set message_idx to current position and break
// This matches Python's behavior where message_idx is set before break
message_idx = i;
break;
}
if num_tokens >= max_tokens && messages[i].role == Role::User {
// Set message_idx to current position and break
// This matches Python's behavior where message_idx is set before break
message_idx = i;
break;
}
}
// Only update message_idx if we haven't hit the token limit yet
@ -789,7 +811,11 @@ impl ArchFunctionHandler {
}
/// Helper to create a request with VLLM-specific parameters
fn create_request_with_extra_body(&self, messages: Vec<Message>, stream: bool) -> ChatCompletionsRequest {
fn create_request_with_extra_body(
&self,
messages: Vec<Message>,
stream: bool,
) -> ChatCompletionsRequest {
ChatCompletionsRequest {
model: self.model_name.clone(),
messages,
@ -813,24 +839,38 @@ impl ArchFunctionHandler {
}
/// Makes a streaming request and returns the SSE event stream
async fn make_streaming_request(&self, request: ChatCompletionsRequest) -> Result<std::pin::Pin<Box<dyn futures::Stream<Item = std::result::Result<Value, String>> + Send>>> {
let request_body = serde_json::to_string(&request)
.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Failed to serialize request: {}", e)))?;
async fn make_streaming_request(
&self,
request: ChatCompletionsRequest,
) -> Result<
std::pin::Pin<Box<dyn futures::Stream<Item = std::result::Result<Value, String>> + Send>>,
> {
let request_body = serde_json::to_string(&request).map_err(|e| {
FunctionCallingError::InvalidModelResponse(format!(
"Failed to serialize request: {}",
e
))
})?;
let response = self.http_client
let response = self
.http_client
.post(&self.endpoint_url)
.header("Content-Type", "application/json")
.body(request_body)
.send()
.await
.map_err(|e| FunctionCallingError::HttpError(e))?;
.map_err(FunctionCallingError::HttpError)?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
return Err(FunctionCallingError::InvalidModelResponse(
format!("HTTP error {}: {}", status, error_text)
));
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(FunctionCallingError::InvalidModelResponse(format!(
"HTTP error {}: {}",
status, error_text
)));
}
// Parse SSE stream
@ -856,38 +896,51 @@ impl ArchFunctionHandler {
}
/// Makes a non-streaming request and returns the response
async fn make_non_streaming_request(&self, request: ChatCompletionsRequest) -> Result<ChatCompletionsResponse> {
let request_body = serde_json::to_string(&request)
.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Failed to serialize request: {}", e)))?;
async fn make_non_streaming_request(
&self,
request: ChatCompletionsRequest,
) -> Result<ChatCompletionsResponse> {
let request_body = serde_json::to_string(&request).map_err(|e| {
FunctionCallingError::InvalidModelResponse(format!(
"Failed to serialize request: {}",
e
))
})?;
let response = self.http_client
let response = self
.http_client
.post(&self.endpoint_url)
.header("Content-Type", "application/json")
.body(request_body)
.send()
.await
.map_err(|e| FunctionCallingError::HttpError(e))?;
.map_err(FunctionCallingError::HttpError)?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
return Err(FunctionCallingError::InvalidModelResponse(
format!("HTTP error {}: {}", status, error_text)
));
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(FunctionCallingError::InvalidModelResponse(format!(
"HTTP error {}: {}",
status, error_text
)));
}
let response_text = response.text().await
.map_err(|e| FunctionCallingError::HttpError(e))?;
let response_text = response
.text()
.await
.map_err(FunctionCallingError::HttpError)?;
serde_json::from_str(&response_text)
.map_err(|e| FunctionCallingError::JsonParseError(e))
serde_json::from_str(&response_text).map_err(FunctionCallingError::JsonParseError)
}
pub async fn function_calling_chat(
&self,
request: ChatCompletionsRequest,
) -> Result<ChatCompletionsResponse> {
use tracing::{info, error};
use tracing::{error, info};
info!("[Arch-Function] - ChatCompletion");
@ -899,10 +952,14 @@ impl ArchFunctionHandler {
request.metadata.as_ref(),
)?;
info!("[request to arch-fc]: model: {}, messages count: {}",
self.model_name, messages.len());
info!(
"[request to arch-fc]: model: {}, messages count: {}",
self.model_name,
messages.len()
);
let use_agent_orchestrator = request.metadata
let use_agent_orchestrator = request
.metadata
.as_ref()
.and_then(|m| m.get("use_agent_orchestrator"))
.and_then(|v| v.as_bool())
@ -918,89 +975,95 @@ impl ArchFunctionHandler {
if use_agent_orchestrator {
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?;
let chunk = chunk_result.map_err(FunctionCallingError::InvalidModelResponse)?;
// Extract content from JSON response
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
if let Some(choice) = choices.first() {
if let Some(content) = choice.get("delta")
if let Some(content) = choice
.get("delta")
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str()) {
.and_then(|c| c.as_str())
{
model_response.push_str(content);
}
}
}
}
info!("[Agent Orchestrator]: response received");
} else {
if let Some(tools) = request.tools.as_ref() {
let mut hallucination_state = HallucinationState::new(tools);
let mut has_tool_calls = None;
let mut has_hallucination = false;
} else if let Some(tools) = request.tools.as_ref() {
let mut hallucination_state = HallucinationState::new(tools);
let mut has_tool_calls = None;
let mut has_hallucination = false;
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?;
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result.map_err(FunctionCallingError::InvalidModelResponse)?;
// Extract content and logprobs from JSON response
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
if let Some(choice) = choices.first() {
if let Some(content) = choice.get("delta")
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str()) {
// Extract content and logprobs from JSON response
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
if let Some(choice) = choices.first() {
if let Some(content) = choice
.get("delta")
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str())
{
// Extract logprobs
let logprobs: Vec<f64> = choice
.get("logprobs")
.and_then(|lp| lp.get("content"))
.and_then(|c| c.as_array())
.and_then(|arr| arr.first())
.and_then(|token| token.get("top_logprobs"))
.and_then(|tlp| tlp.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.get("logprob").and_then(|lp| lp.as_f64()))
.collect()
})
.unwrap_or_default();
// Extract logprobs
let logprobs: Vec<f64> = choice.get("logprobs")
.and_then(|lp| lp.get("content"))
.and_then(|c| c.as_array())
.and_then(|arr| arr.first())
.and_then(|token| token.get("top_logprobs"))
.and_then(|tlp| tlp.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.get("logprob").and_then(|lp| lp.as_f64()))
.collect()
})
.unwrap_or_default();
if hallucination_state
.append_and_check_token_hallucination(content.to_string(), logprobs)
{
has_hallucination = true;
break;
}
if hallucination_state.append_and_check_token_hallucination(content.to_string(), logprobs) {
has_hallucination = true;
break;
}
if hallucination_state.tokens.len() > 5 && has_tool_calls.is_none() {
let collected_content = hallucination_state.tokens.join("");
has_tool_calls = Some(collected_content.contains("tool_calls"));
}
if hallucination_state.tokens.len() > 5 && has_tool_calls.is_none() {
let collected_content = hallucination_state.tokens.join("");
has_tool_calls = Some(collected_content.contains("tool_calls"));
}
}
}
}
}
if has_tool_calls == Some(true) && has_hallucination {
info!("[Hallucination]: {}", hallucination_state.error_message);
if has_tool_calls == Some(true) && has_hallucination {
info!("[Hallucination]: {}", hallucination_state.error_message);
let clarify_messages = self.prefill_message(messages.clone(), &self.clarify_prefix);
let clarify_request = self.create_request_with_extra_body(clarify_messages, false);
let clarify_messages = self.prefill_message(messages.clone(), &self.clarify_prefix);
let clarify_request = self.create_request_with_extra_body(clarify_messages, false);
let retry_response = self.make_non_streaming_request(clarify_request).await?;
let retry_response = self.make_non_streaming_request(clarify_request).await?;
if let Some(choice) = retry_response.choices.first() {
if let Some(content) = &choice.message.content {
model_response = content.clone();
}
if let Some(choice) = retry_response.choices.first() {
if let Some(content) = &choice.message.content {
model_response = content.clone();
}
} else {
model_response = hallucination_state.tokens.join("");
}
} else {
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?;
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
if let Some(choice) = choices.first() {
if let Some(content) = choice.get("delta")
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str()) {
model_response.push_str(content);
}
model_response = hallucination_state.tokens.join("");
}
} else {
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result.map_err(FunctionCallingError::InvalidModelResponse)?;
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
if let Some(choice) = choices.first() {
if let Some(content) = choice
.get("delta")
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str())
{
model_response.push_str(content);
}
}
}
@ -1009,10 +1072,17 @@ impl ArchFunctionHandler {
let response_dict = self.parse_model_response(&model_response);
info!("[arch-fc]: raw model response: {}", response_dict.raw_response);
info!(
"[arch-fc]: raw model response: {}",
response_dict.raw_response
);
// General model response (no intent matched - should route to default target)
let model_message = if response_dict.response.as_ref().map_or(false, |s| !s.is_empty()) {
let model_message = if response_dict
.response
.as_ref()
.is_some_and(|s| !s.is_empty())
{
// When arch-fc returns a "response" field, it means no intent was matched
// Return empty content and empty tool_calls so prompt_gateway routes to default target
ResponseMessage {
@ -1053,8 +1123,11 @@ impl ArchFunctionHandler {
let verification = self.verify_tool_calls(tools, &response_dict.tool_calls);
if verification.is_valid {
info!("[Tool calls]: {:?}",
response_dict.tool_calls.iter()
info!(
"[Tool calls]: {:?}",
response_dict
.tool_calls
.iter()
.map(|tc| &tc.function)
.collect::<Vec<_>>()
);
@ -1092,8 +1165,11 @@ impl ArchFunctionHandler {
}
}
} else {
info!("[Tool calls]: {:?}",
response_dict.tool_calls.iter()
info!(
"[Tool calls]: {:?}",
response_dict
.tool_calls
.iter()
.map(|tc| &tc.function)
.collect::<Vec<_>>()
);
@ -1108,7 +1184,10 @@ impl ArchFunctionHandler {
}
}
} else {
error!("Invalid tool calls in response: {}", response_dict.error_message);
error!(
"Invalid tool calls in response: {}",
response_dict.error_message
);
ResponseMessage {
role: Role::Assistant,
content: Some(String::new()),
@ -1243,7 +1322,6 @@ pub async fn function_calling_chat_handler(
req: Request<Incoming>,
llm_provider_url: String,
) -> std::result::Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
use hermesllm::apis::openai::ChatCompletionsRequest;
let whole_body = req.collect().await?.to_bytes();
@ -1255,10 +1333,13 @@ pub async fn function_calling_chat_handler(
let mut response = Response::new(full(
serde_json::json!({
"error": format!("Invalid request body: {}", e)
}).to_string()
})
.to_string(),
));
*response.status_mut() = StatusCode::BAD_REQUEST;
response.headers_mut().insert("Content-Type", "application/json".parse().unwrap());
response
.headers_mut()
.insert("Content-Type", "application/json".parse().unwrap());
return Ok(response);
}
};
@ -1271,24 +1352,31 @@ pub async fn function_calling_chat_handler(
// Parse as ChatCompletionsRequest
let chat_request: ChatCompletionsRequest = match serde_json::from_value(body_json) {
Ok(req) => {
info!("[request body]: {}", serde_json::to_string(&req).unwrap_or_default());
info!(
"[request body]: {}",
serde_json::to_string(&req).unwrap_or_default()
);
req
},
}
Err(e) => {
error!("Failed to parse request body: {}", e);
let mut response = Response::new(full(
serde_json::json!({
"error": format!("Invalid request body: {}", e)
}).to_string()
})
.to_string(),
));
*response.status_mut() = StatusCode::BAD_REQUEST;
response.headers_mut().insert("Content-Type", "application/json".parse().unwrap());
response
.headers_mut()
.insert("Content-Type", "application/json".parse().unwrap());
return Ok(response);
}
};
// Determine which handler to use based on metadata
let use_agent_orchestrator = chat_request.metadata
let use_agent_orchestrator = chat_request
.metadata
.as_ref()
.and_then(|m| m.get("use_agent_orchestrator"))
.and_then(|v| v.as_bool())
@ -1309,7 +1397,10 @@ pub async fn function_calling_chat_handler(
ARCH_FUNCTION_MODEL_NAME.to_string(),
llm_provider_url.clone(),
);
handler.function_handler.function_calling_chat(chat_request).await
handler
.function_handler
.function_calling_chat(chat_request)
.await
} else {
let handler = ArchFunctionHandler::new(
ARCH_FUNCTION_MODEL_NAME.to_string(),
@ -1328,7 +1419,9 @@ pub async fn function_calling_chat_handler(
let mut response = Response::new(full(response_json));
*response.status_mut() = StatusCode::OK;
response.headers_mut().insert("Content-Type", "application/json".parse().unwrap());
response
.headers_mut()
.insert("Content-Type", "application/json".parse().unwrap());
Ok(response)
}
@ -1341,13 +1434,14 @@ pub async fn function_calling_chat_handler(
let mut response = Response::new(full(error_response.to_string()));
*response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
response.headers_mut().insert("Content-Type", "application/json".parse().unwrap());
response
.headers_mut()
.insert("Content-Type", "application/json".parse().unwrap());
Ok(response)
}
}
}
// ============================================================================
// TESTS
// ============================================================================
@ -1370,10 +1464,13 @@ mod tests {
assert!(config.task_prompt.contains("</tools>\\n\\n"));
// Format prompt should contain literal escaped newlines and proper JSON examples
assert!(config.format_prompt.contains("\\n\\nBased on your analysis"));
assert!(config.format_prompt.contains(r#"{\"response\": \"Your response text here\"}"#));
assert!(config
.format_prompt
.contains("\\n\\nBased on your analysis"));
assert!(config
.format_prompt
.contains(r#"{\"response\": \"Your response text here\"}"#));
assert!(config.format_prompt.contains(r#"{\"tool_calls\": [{"#));
}
#[test]
@ -1384,7 +1481,11 @@ mod tests {
#[test]
fn test_fix_json_string_valid() {
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string());
let handler = ArchFunctionHandler::new(
"test-model".to_string(),
ArchFunctionConfig::default(),
"http://localhost:8000".to_string(),
);
let json_str = r#"{"name": "test", "value": 123}"#;
let result = handler.fix_json_string(json_str);
assert!(result.is_ok());
@ -1392,7 +1493,11 @@ mod tests {
#[test]
fn test_fix_json_string_missing_bracket() {
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string());
let handler = ArchFunctionHandler::new(
"test-model".to_string(),
ArchFunctionConfig::default(),
"http://localhost:8000".to_string(),
);
let json_str = r#"{"name": "test", "value": 123"#;
let result = handler.fix_json_string(json_str);
assert!(result.is_ok());
@ -1402,8 +1507,13 @@ mod tests {
#[test]
fn test_parse_model_response_with_tool_calls() {
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string());
let content = r#"{"tool_calls": [{"name": "get_weather", "arguments": {"location": "NYC"}}]}"#;
let handler = ArchFunctionHandler::new(
"test-model".to_string(),
ArchFunctionConfig::default(),
"http://localhost:8000".to_string(),
);
let content =
r#"{"tool_calls": [{"name": "get_weather", "arguments": {"location": "NYC"}}]}"#;
let result = handler.parse_model_response(content);
assert!(result.is_valid);
@ -1413,8 +1523,13 @@ mod tests {
#[test]
fn test_parse_model_response_with_clarification() {
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string());
let content = r#"{"required_functions": ["get_weather"], "clarification": "What location?"}"#;
let handler = ArchFunctionHandler::new(
"test-model".to_string(),
ArchFunctionConfig::default(),
"http://localhost:8000".to_string(),
);
let content =
r#"{"required_functions": ["get_weather"], "clarification": "What location?"}"#;
let result = handler.parse_model_response(content);
assert!(result.is_valid);
@ -1424,7 +1539,11 @@ mod tests {
#[test]
fn test_convert_data_type_int_to_float() {
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string());
let handler = ArchFunctionHandler::new(
"test-model".to_string(),
ArchFunctionConfig::default(),
"http://localhost:8000".to_string(),
);
let value = json!(42);
let result = handler.convert_data_type(&value, "float");
assert!(result.is_ok());
@ -1504,13 +1623,12 @@ pub fn check_threshold(
}
/// Checks if a parameter is required in the function description
pub fn is_parameter_required(
function_description: &Value,
parameter_name: &str,
) -> bool {
pub fn is_parameter_required(function_description: &Value, parameter_name: &str) -> bool {
if let Some(required) = function_description.get("required") {
if let Some(required_arr) = required.as_array() {
return required_arr.iter().any(|v| v.as_str() == Some(parameter_name));
return required_arr
.iter()
.any(|v| v.as_str() == Some(parameter_name));
}
}
false
@ -1559,12 +1677,7 @@ impl HallucinationState {
pub fn new(functions: &[Tool]) -> Self {
let function_properties: HashMap<String, Value> = functions
.iter()
.map(|tool| {
(
tool.function.name.clone(),
tool.function.parameters.clone(),
)
})
.map(|tool| (tool.function.name.clone(), tool.function.parameters.clone()))
.collect();
Self {
@ -1620,7 +1733,10 @@ impl HallucinationState {
// Function name extraction logic
if self.state.as_deref() == Some("function_name") {
if !FUNC_NAME_END_TOKEN.iter().any(|&t| self.tokens.last().map_or(false, |tok| tok == t)) {
if !FUNC_NAME_END_TOKEN
.iter()
.any(|&t| self.tokens.last().is_some_and(|tok| tok == t))
{
self.mask.push(MaskToken::FunctionName);
} else {
self.state = None;
@ -1629,34 +1745,51 @@ impl HallucinationState {
}
// Check for function name start
if FUNC_NAME_START_PATTERN.iter().any(|&p| content.ends_with(p)) {
if FUNC_NAME_START_PATTERN
.iter()
.any(|&p| content.ends_with(p))
{
self.state = Some("function_name".to_string());
}
// Parameter name extraction logic
if self.state.as_deref() == Some("parameter_name")
&& !PARAMETER_NAME_END_TOKENS.iter().any(|&t| content.ends_with(t)) {
&& !PARAMETER_NAME_END_TOKENS
.iter()
.any(|&t| content.ends_with(t))
{
self.mask.push(MaskToken::ParameterName);
} else if self.state.as_deref() == Some("parameter_name")
&& PARAMETER_NAME_END_TOKENS.iter().any(|&t| content.ends_with(t)) {
&& PARAMETER_NAME_END_TOKENS
.iter()
.any(|&t| content.ends_with(t))
{
self.state = None;
self.parameter_name_done = true;
self.get_parameter_name();
} else if self.parameter_name_done
&& !self.open_bracket
&& PARAMETER_NAME_START_PATTERN.iter().any(|&p| content.ends_with(p)) {
&& PARAMETER_NAME_START_PATTERN
.iter()
.any(|&p| content.ends_with(p))
{
self.state = Some("parameter_name".to_string());
}
// First parameter value start
if FIRST_PARAM_NAME_START_PATTERN.iter().any(|&p| content.ends_with(p)) {
if FIRST_PARAM_NAME_START_PATTERN
.iter()
.any(|&p| content.ends_with(p))
{
self.state = Some("parameter_name".to_string());
}
// Parameter value extraction logic
if self.state.as_deref() == Some("parameter_value")
&& !PARAMETER_VALUE_END_TOKEN.iter().any(|&t| content.ends_with(t)) {
&& !PARAMETER_VALUE_END_TOKEN
.iter()
.any(|&t| content.ends_with(t))
{
// Check for brackets
if let Some(last_token) = self.tokens.last() {
let open_brackets: Vec<char> = last_token
@ -1694,8 +1827,11 @@ impl HallucinationState {
&& self.mask[self.mask.len() - 2] != MaskToken::ParameterValue
&& !self.parameter_name.is_empty()
{
let last_param = self.parameter_name[self.parameter_name.len() - 1].clone();
if let Some(func_props) = self.function_properties.get(&self.function_name) {
let last_param =
self.parameter_name[self.parameter_name.len() - 1].clone();
if let Some(func_props) =
self.function_properties.get(&self.function_name)
{
if is_parameter_required(func_props, &last_param)
&& !is_parameter_property(func_props, &last_param, "enum")
&& !self.check_parameter_name.contains_key(&last_param)
@ -1718,10 +1854,16 @@ impl HallucinationState {
}
} else if self.state.as_deref() == Some("parameter_value")
&& !self.open_bracket
&& PARAMETER_VALUE_END_TOKEN.iter().any(|&t| content.ends_with(t)) {
&& PARAMETER_VALUE_END_TOKEN
.iter()
.any(|&t| content.ends_with(t))
{
self.state = None;
} else if self.parameter_name_done
&& PARAMETER_VALUE_START_PATTERN.iter().any(|&p| content.ends_with(p)) {
&& PARAMETER_VALUE_START_PATTERN
.iter()
.any(|&p| content.ends_with(p))
{
self.state = Some("parameter_value".to_string());
}
@ -1848,18 +1990,18 @@ mod hallucination_tests {
let handler = ArchFunctionHandler::new(
"test-model".to_string(),
ArchFunctionConfig::default(),
"http://localhost:8000".to_string()
"http://localhost:8000".to_string(),
);
// Test integer types
assert!(handler.check_value_type(&json!(42), "integer"));
assert!(handler.check_value_type(&json!(42), "int"));
assert!(!handler.check_value_type(&json!(3.14), "integer"));
assert!(!handler.check_value_type(&json!(3.15), "integer"));
// Test number types (accepts both int and float)
assert!(handler.check_value_type(&json!(3.14), "number"));
assert!(handler.check_value_type(&json!(3.15), "number"));
assert!(handler.check_value_type(&json!(42), "number"));
assert!(handler.check_value_type(&json!(3.14), "float"));
assert!(handler.check_value_type(&json!(3.15), "float"));
// Test boolean
assert!(handler.check_value_type(&json!(true), "boolean"));
@ -1890,12 +2032,16 @@ mod hallucination_tests {
let handler = ArchFunctionHandler::new(
"test-model".to_string(),
ArchFunctionConfig::default(),
"http://localhost:8000".to_string()
"http://localhost:8000".to_string(),
);
// Test valid type - no conversion needed
assert!(handler.validate_or_convert_parameter(&json!(42), "integer").unwrap());
assert!(handler.validate_or_convert_parameter(&json!("hello"), "string").unwrap());
assert!(handler
.validate_or_convert_parameter(&json!(42), "integer")
.unwrap());
assert!(handler
.validate_or_convert_parameter(&json!("hello"), "string")
.unwrap());
// Test integer to float conversion (convert_data_type supports this)
let result = handler.validate_or_convert_parameter(&json!(42), "float");
@ -1910,8 +2056,12 @@ mod hallucination_tests {
assert!(!result.unwrap());
// Test number accepting both int and float
assert!(handler.validate_or_convert_parameter(&json!(42), "number").unwrap());
assert!(handler.validate_or_convert_parameter(&json!(3.14), "number").unwrap());
assert!(handler
.validate_or_convert_parameter(&json!(42), "number")
.unwrap());
assert!(handler
.validate_or_convert_parameter(&json!(3.15), "number")
.unwrap());
}
#[test]

View file

@ -14,7 +14,7 @@ use crate::router::plano_orchestrator::OrchestratorService;
/// 2. PipelineProcessor - executes the agent pipeline
/// 3. ResponseHandler - handles response streaming
#[cfg(test)]
mod integration_tests {
mod tests {
use super::*;
use common::configuration::{Agent, AgentFilterChain, Listener};
@ -62,7 +62,10 @@ mod integration_tests {
let agent_pipeline = AgentFilterChain {
id: "terminal-agent".to_string(),
filter_chain: Some(vec!["filter-agent".to_string(), "terminal-agent".to_string()]),
filter_chain: Some(vec![
"filter-agent".to_string(),
"terminal-agent".to_string(),
]),
description: Some("Test pipeline".to_string()),
default: Some(true),
};

View file

@ -2,48 +2,48 @@ use serde::{Deserialize, Serialize};
use std::collections::HashMap;
pub const JSON_RPC_VERSION: &str = "2.0";
pub const TOOL_CALL_METHOD : &str = "tools/call";
pub const TOOL_CALL_METHOD: &str = "tools/call";
pub const MCP_INITIALIZE: &str = "initialize";
pub const MCP_INITIALIZE_NOTIFICATION: &str = "notifications/initialized";
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum JsonRpcId {
String(String),
Number(u64),
String(String),
Number(u64),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcRequest {
pub jsonrpc: String,
pub id: JsonRpcId,
pub method: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<HashMap<String, serde_json::Value>>,
pub jsonrpc: String,
pub id: JsonRpcId,
pub method: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcNotification {
pub jsonrpc: String,
pub method: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<HashMap<String, serde_json::Value>>,
pub jsonrpc: String,
pub method: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcError {
pub code: i32,
pub message: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub data: Option<serde_json::Value>,
pub code: i32,
pub message: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub data: Option<serde_json::Value>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcResponse {
pub jsonrpc: String,
pub id: JsonRpcId,
#[serde(skip_serializing_if = "Option::is_none")]
pub result: Option<HashMap<String, serde_json::Value>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<JsonRpcError>,
pub jsonrpc: String,
pub id: JsonRpcId,
#[serde(skip_serializing_if = "Option::is_none")]
pub result: Option<HashMap<String, serde_json::Value>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<JsonRpcError>,
}

View file

@ -1,6 +1,8 @@
use bytes::Bytes;
use common::configuration::{LlmProvider, ModelAlias};
use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER};
use common::consts::{
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
};
use common::traces::TraceCollector;
use hermesllm::apis::openai_responses::InputParam;
use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
@ -14,13 +16,14 @@ use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
use crate::router::llm_router::RouterService;
use crate::handlers::utils::{create_streaming_response, ObservableStreamProcessor, truncate_message};
use crate::handlers::router_chat::router_chat_get_upstream_model;
use crate::handlers::utils::{
create_streaming_response, truncate_message, ObservableStreamProcessor,
};
use crate::router::llm_router::RouterService;
use crate::state::response_state_processor::ResponsesStateProcessor;
use crate::state::{
StateStorage, StateStorageError,
extract_input_items, retrieve_and_combine_input
extract_input_items, retrieve_and_combine_input, StateStorage, StateStorageError,
};
use crate::tracing::operation_component;
@ -39,7 +42,6 @@ pub async fn llm_chat(
trace_collector: Arc<TraceCollector>,
state_storage: Option<Arc<dyn StateStorage>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let request_path = request.uri().path().to_string();
let request_headers = request.headers().clone();
let request_id = request_headers
@ -74,8 +76,14 @@ pub async fn llm_chat(
)) {
Ok(request) => request,
Err(err) => {
warn!("[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request as ProviderRequestType: {}", request_id, err);
let err_msg = format!("[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request: {}", request_id, err);
warn!(
"[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request as ProviderRequestType: {}",
request_id, err
);
let err_msg = format!(
"[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request: {}",
request_id, err
);
let mut bad_request = Response::new(full(err_msg));
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
return Ok(bad_request);
@ -85,7 +93,10 @@ pub async fn llm_chat(
// === v1/responses state management: Extract input items early ===
let mut original_input_items = Vec::new();
let client_api = SupportedAPIsFromClient::from_endpoint(request_path.as_str());
let is_responses_api_client = matches!(client_api, Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_)));
let is_responses_api_client = matches!(
client_api,
Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_))
);
// Model alias resolution: update model field in client_request immediately
// This ensures all downstream objects use the resolved model
@ -96,20 +107,28 @@ pub async fn llm_chat(
// Extract tool names and user message preview for span attributes
let tool_names = client_request.get_tool_names();
let user_message_preview = client_request.get_recent_user_message()
let user_message_preview = client_request
.get_recent_user_message()
.map(|msg| truncate_message(&msg, 50));
client_request.set_model(resolved_model.clone());
if client_request.remove_metadata_key("archgw_preference_config") {
debug!("[PLANO_REQ_ID:{}] Removed archgw_preference_config from metadata", request_id);
debug!(
"[PLANO_REQ_ID:{}] Removed archgw_preference_config from metadata",
request_id
);
}
// === v1/responses state management: Determine upstream API and combine input if needed ===
// Do this BEFORE routing since routing consumes the request
// Only process state if state_storage is configured
let mut should_manage_state = false;
if is_responses_api_client && state_storage.is_some() {
if let ProviderRequestType::ResponsesAPIRequest(ref mut responses_req) = client_request {
if is_responses_api_client {
if let (
ProviderRequestType::ResponsesAPIRequest(ref mut responses_req),
Some(ref state_store),
) = (&mut client_request, &state_storage)
{
// Extract original input once
original_input_items = extract_input_items(&responses_req.input);
@ -120,18 +139,22 @@ pub async fn llm_chat(
&request_path,
&resolved_model,
is_streaming_request,
).await;
)
.await;
let upstream_api = SupportedUpstreamAPIs::from_endpoint(&upstream_path);
// Only manage state if upstream is NOT OpenAIResponsesAPI (needs translation)
should_manage_state = !matches!(upstream_api, Some(SupportedUpstreamAPIs::OpenAIResponsesAPI(_)));
should_manage_state = !matches!(
upstream_api,
Some(SupportedUpstreamAPIs::OpenAIResponsesAPI(_))
);
if should_manage_state {
// Retrieve and combine conversation history if previous_response_id exists
if let Some(ref prev_resp_id) = responses_req.previous_response_id {
match retrieve_and_combine_input(
state_storage.as_ref().unwrap().clone(),
state_store.clone(),
prev_resp_id,
original_input_items, // Pass ownership instead of cloning
)
@ -166,7 +189,10 @@ pub async fn llm_chat(
}
}
} else {
debug!("[PLANO_REQ_ID:{}] | BRIGHT_STAFF | Upstream supports ResponsesAPI natively.", request_id);
debug!(
"[PLANO_REQ_ID:{}] | BRIGHT_STAFF | Upstream supports ResponsesAPI natively.",
request_id
);
}
}
}
@ -177,7 +203,7 @@ pub async fn llm_chat(
// Determine routing using the dedicated router_chat module
let routing_result = match router_chat_get_upstream_model(
router_service,
client_request, // Pass the original request - router_chat will convert it
client_request, // Pass the original request - router_chat will convert it
&request_headers,
trace_collector.clone(),
&traceparent,
@ -257,7 +283,8 @@ pub async fn llm_chat(
user_message_preview,
temperature,
&llm_providers,
).await;
)
.await;
// Create base processor for metrics and tracing
let base_processor = ObservableStreamProcessor::new(
@ -269,7 +296,11 @@ pub async fn llm_chat(
// === v1/responses state management: Wrap with ResponsesStateProcessor ===
// Only wrap if we need to manage state (client is ResponsesAPI AND upstream is NOT ResponsesAPI AND state_storage is configured)
let streaming_response = if should_manage_state && !original_input_items.is_empty() && state_storage.is_some() {
let streaming_response = if let (true, false, Some(state_store)) = (
should_manage_state,
original_input_items.is_empty(),
state_storage,
) {
// Extract Content-Encoding header to handle decompression for state parsing
let content_encoding = response_headers
.get("content-encoding")
@ -279,7 +310,7 @@ pub async fn llm_chat(
// Wrap with state management processor to store state after response completes
let state_processor = ResponsesStateProcessor::new(
base_processor,
state_storage.unwrap(),
state_store,
original_input_items,
resolved_model.clone(),
model_name.clone(),
@ -324,6 +355,7 @@ fn resolve_model_alias(
}
/// Builds the LLM span with all required and optional attributes.
#[allow(clippy::too_many_arguments)]
async fn build_llm_span(
traceparent: &str,
request_path: &str,
@ -337,8 +369,8 @@ async fn build_llm_span(
temperature: Option<f32>,
llm_providers: &Arc<RwLock<Vec<LlmProvider>>>,
) -> common::traces::Span {
use common::traces::{SpanBuilder, SpanKind, parse_traceparent};
use crate::tracing::{http, llm, OperationNameBuilder};
use common::traces::{parse_traceparent, SpanBuilder, SpanKind};
// Calculate the upstream path based on provider configuration
let upstream_path = get_upstream_path(
@ -347,13 +379,14 @@ async fn build_llm_span(
request_path,
resolved_model,
is_streaming,
).await;
)
.await;
// Build operation name showing path transformation if different
let operation_name = if request_path != upstream_path {
OperationNameBuilder::new()
.with_method("POST")
.with_path(&format!("{} >> {}", request_path, upstream_path))
.with_path(format!("{} >> {}", request_path, upstream_path))
.with_target(resolved_model)
.build()
} else {
@ -388,7 +421,8 @@ async fn build_llm_span(
}
if let Some(tools) = tool_names {
let formatted_tools = tools.iter()
let formatted_tools = tools
.iter()
.map(|name| format!("{}(...)", name))
.collect::<Vec<_>>()
.join("\n");
@ -436,8 +470,7 @@ async fn get_provider_info(
// First, try to find by model name or provider name
let provider = providers_lock.iter().find(|p| {
p.model.as_ref().map(|m| m == model_name).unwrap_or(false)
|| p.name == model_name
p.model.as_ref().map(|m| m == model_name).unwrap_or(false) || p.name == model_name
});
if let Some(provider) = provider {
@ -446,9 +479,7 @@ async fn get_provider_info(
return (provider_id, prefix);
}
let default_provider = providers_lock.iter().find(|p| {
p.default.unwrap_or(false)
});
let default_provider = providers_lock.iter().find(|p| p.default.unwrap_or(false));
if let Some(provider) = default_provider {
let provider_id = provider.provider_interface.to_provider_id();

View file

@ -1,13 +1,13 @@
pub mod agent_chat_completions;
pub mod agent_selector;
pub mod llm;
pub mod router_chat;
pub mod models;
pub mod function_calling;
pub mod jsonrpc;
pub mod llm;
pub mod models;
pub mod pipeline_processor;
pub mod response_handler;
pub mod router_chat;
pub mod utils;
pub mod jsonrpc;
#[cfg(test)]
mod integration_tests;

View file

@ -82,6 +82,7 @@ impl PipelineProcessor {
}
/// Record a span for filter execution
#[allow(clippy::too_many_arguments)]
fn record_filter_span(
&self,
collector: &std::sync::Arc<common::traces::TraceCollector>,
@ -132,6 +133,7 @@ impl PipelineProcessor {
}
/// Record a span for MCP protocol interactions
#[allow(clippy::too_many_arguments)]
fn record_agent_filter_span(
&self,
collector: &std::sync::Arc<common::traces::TraceCollector>,
@ -156,12 +158,12 @@ impl PipelineProcessor {
.build();
let mut span_builder = SpanBuilder::new(&operation_name)
.with_span_id(span_id.unwrap_or_else(|| generate_random_span_id()))
.with_span_id(span_id.unwrap_or_else(generate_random_span_id))
.with_kind(SpanKind::Client)
.with_start_time(start_time)
.with_end_time(end_time)
.with_attribute(http::METHOD, "POST")
.with_attribute(http::TARGET, &format!("/mcp ({})", operation.to_string()))
.with_attribute(http::TARGET, format!("/mcp ({})", operation))
.with_attribute("mcp.operation", operation.to_string())
.with_attribute("mcp.agent_id", agent_id.to_string())
.with_attribute(
@ -188,6 +190,7 @@ impl PipelineProcessor {
}
/// Process the filter chain of agents (all except the terminal agent)
#[allow(clippy::too_many_arguments)]
pub async fn process_filter_chain(
&mut self,
chat_history: &[Message],
@ -1023,7 +1026,7 @@ mod tests {
}
});
let sse_body = format!("event: message\ndata: {}\n\n", rpc_body.to_string());
let sse_body = format!("event: message\ndata: {}\n\n", rpc_body);
let mut server = Server::new_async().await;
let _m = server
@ -1061,10 +1064,10 @@ mod tests {
.await;
match result {
Err(PipelineError::ClientError { status, body, .. }) => {
assert_eq!(status, 400);
assert_eq!(body, "bad tool call");
}
Err(PipelineError::ClientError { status, body, .. }) => {
assert_eq!(status, 400);
assert_eq!(body, "bad tool call");
}
_ => panic!("Expected client error when isError flag is set"),
}
}

View file

@ -133,9 +133,7 @@ impl ResponseHandler {
let response_headers = llm_response.headers();
let is_sse_streaming = response_headers
.get(hyper::header::CONTENT_TYPE)
.map_or(false, |v| {
v.to_str().unwrap_or("").contains("text/event-stream")
});
.is_some_and(|v| v.to_str().unwrap_or("").contains("text/event-stream"));
let response_bytes = llm_response
.bytes()
@ -164,7 +162,7 @@ impl ResponseHandler {
match transformed_event.provider_response() {
Ok(provider_response) => {
if let Some(content) = provider_response.content_delta() {
accumulated_text.push_str(&content);
accumulated_text.push_str(content);
} else {
info!("No content delta in provider response");
}
@ -174,7 +172,7 @@ impl ResponseHandler {
}
}
}
return Ok(accumulated_text);
Ok(accumulated_text)
} else {
// If not SSE, treat as regular text response
let response_text = String::from_utf8(response_bytes.to_vec()).map_err(|e| {

View file

@ -1,6 +1,6 @@
use common::configuration::ModelUsagePreference;
use common::consts::{REQUEST_ID_HEADER};
use common::traces::{TraceCollector, SpanKind, SpanBuilder, parse_traceparent};
use common::consts::REQUEST_ID_HEADER;
use common::traces::{parse_traceparent, SpanBuilder, SpanKind, TraceCollector};
use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
use hermesllm::{ProviderRequest, ProviderRequestType};
use hyper::StatusCode;
@ -9,10 +9,10 @@ use std::sync::Arc;
use tracing::{debug, info, warn};
use crate::router::llm_router::RouterService;
use crate::tracing::{OperationNameBuilder, operation_component, http, routing};
use crate::tracing::{http, operation_component, routing, OperationNameBuilder};
pub struct RoutingResult {
pub model_name: String
pub model_name: String,
}
pub struct RoutingError {
@ -24,7 +24,7 @@ impl RoutingError {
pub fn internal_error(message: String) -> Self {
Self {
message,
status_code: StatusCode::INTERNAL_SERVER_ERROR
status_code: StatusCode::INTERNAL_SERVER_ERROR,
}
}
}
@ -52,9 +52,7 @@ pub async fn router_chat_get_upstream_model(
// Convert to ChatCompletionsRequest for routing (regardless of input type)
let chat_request = match ProviderRequestType::try_from((
client_request,
&SupportedUpstreamAPIs::OpenAIChatCompletions(
hermesllm::apis::OpenAIApi::ChatCompletions,
),
&SupportedUpstreamAPIs::OpenAIChatCompletions(hermesllm::apis::OpenAIApi::ChatCompletions),
)) {
Ok(ProviderRequestType::ChatCompletionsRequest(req)) => req,
Ok(
@ -69,7 +67,10 @@ pub async fn router_chat_get_upstream_model(
));
}
Err(err) => {
warn!("Failed to convert request to ChatCompletionsRequest: {}", err);
warn!(
"Failed to convert request to ChatCompletionsRequest: {}",
err
);
return Err(RoutingError::internal_error(format!(
"Failed to convert request: {}",
err
@ -151,9 +152,7 @@ pub async fn router_chat_get_upstream_model(
)
.await;
Ok(RoutingResult {
model_name
})
Ok(RoutingResult { model_name })
}
None => {
// No route determined, use default model from request
@ -176,7 +175,7 @@ pub async fn router_chat_get_upstream_model(
.await;
Ok(RoutingResult {
model_name: default_model
model_name: default_model,
})
}
},
@ -194,9 +193,10 @@ pub async fn router_chat_get_upstream_model(
)
.await;
Err(RoutingError::internal_error(
format!("Failed to determine route: {}", err)
))
Err(RoutingError::internal_error(format!(
"Failed to determine route: {}",
err
)))
}
}
}
@ -230,7 +230,10 @@ async fn record_routing_span(
.with_end_time(std::time::SystemTime::now())
.with_attribute(http::METHOD, "POST")
.with_attribute(http::TARGET, routing_api_path.to_string())
.with_attribute(routing::ROUTE_DETERMINATION_MS, start_time.elapsed().as_millis().to_string());
.with_attribute(
routing::ROUTE_DETERMINATION_MS,
start_time.elapsed().as_millis().to_string(),
);
// Only set parent span ID if it exists (not a root span)
if let Some(parent) = parent_span_id {

View file

@ -1,5 +1,5 @@
use bytes::Bytes;
use common::traces::{Span, Attribute, AttributeValue, TraceCollector, Event};
use common::traces::{Attribute, AttributeValue, Event, Span, TraceCollector};
use http_body_util::combinators::BoxBody;
use http_body_util::StreamBody;
use hyper::body::Frame;
@ -11,7 +11,7 @@ use tokio_stream::StreamExt;
use tracing::warn;
// Import tracing constants
use crate::tracing::{llm, error};
use crate::tracing::{error, llm};
/// Trait for processing streaming chunks
/// Implementors can inject custom logic during streaming (e.g., hallucination detection, logging)
@ -97,7 +97,6 @@ impl StreamProcessor for ObservableStreamProcessor {
},
});
self.span.attributes.push(Attribute {
key: llm::DURATION_MS.to_string(),
value: AttributeValue {
@ -119,11 +118,9 @@ impl StreamProcessor for ObservableStreamProcessor {
if let Ok(start_time_nanos) = self.span.start_time_unix_nano.parse::<u128>() {
// Convert ttft from milliseconds to nanoseconds and add to start time
let event_timestamp = start_time_nanos + (ttft * 1_000_000);
let mut event = Event::new(llm::TIME_TO_FIRST_TOKEN_MS.to_string(), event_timestamp);
event.add_attribute(
llm::TIME_TO_FIRST_TOKEN_MS.to_string(),
ttft.to_string(),
);
let mut event =
Event::new(llm::TIME_TO_FIRST_TOKEN_MS.to_string(), event_timestamp);
event.add_attribute(llm::TIME_TO_FIRST_TOKEN_MS.to_string(), ttft.to_string());
// Initialize events vector if needed
if self.span.events.is_none() {
@ -137,7 +134,8 @@ impl StreamProcessor for ObservableStreamProcessor {
}
// Record the finalized span
self.collector.record_span(&self.service_name, self.span.clone());
self.collector
.record_span(&self.service_name, self.span.clone());
}
fn on_error(&mut self, error_msg: &str) {
@ -173,7 +171,8 @@ impl StreamProcessor for ObservableStreamProcessor {
});
// Record the error span
self.collector.record_span(&self.service_name, self.span.clone());
self.collector
.record_span(&self.service_name, self.span.clone());
}
}

View file

@ -4,13 +4,15 @@ use brightstaff::handlers::llm::llm_chat;
use brightstaff::handlers::models::list_models;
use brightstaff::router::llm_router::RouterService;
use brightstaff::router::plano_orchestrator::OrchestratorService;
use brightstaff::state::StateStorage;
use brightstaff::state::postgresql::PostgreSQLConversationStorage;
use brightstaff::state::memory::MemoryConversationalStorage;
use brightstaff::state::postgresql::PostgreSQLConversationStorage;
use brightstaff::state::StateStorage;
use brightstaff::utils::tracing::init_tracer;
use bytes::Bytes;
use common::configuration::{Agent, Configuration};
use common::consts::{CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH, PLANO_ORCHESTRATOR_MODEL_NAME};
use common::consts::{
CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH, PLANO_ORCHESTRATOR_MODEL_NAME,
};
use common::traces::TraceCollector;
use http_body_util::{combinators::BoxBody, BodyExt, Empty};
use hyper::body::Incoming;
@ -105,7 +107,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
PLANO_ORCHESTRATOR_MODEL_NAME.to_string(),
));
let model_aliases = Arc::new(arch_config.model_aliases.clone());
// Initialize trace collector and start background flusher
@ -127,33 +128,33 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
// Configurable via arch_config.yaml state_storage section
// If not configured, state management is disabled
// Environment variables are substituted by envsubst before config is read
let state_storage: Option<Arc<dyn StateStorage>> = if let Some(storage_config) = &arch_config.state_storage {
let storage: Arc<dyn StateStorage> = match storage_config.storage_type {
common::configuration::StateStorageType::Memory => {
info!("Initialized conversation state storage: Memory");
Arc::new(MemoryConversationalStorage::new())
}
common::configuration::StateStorageType::Postgres => {
let connection_string = storage_config
.connection_string
.as_ref()
.expect("connection_string is required for postgres state_storage");
let state_storage: Option<Arc<dyn StateStorage>> =
if let Some(storage_config) = &arch_config.state_storage {
let storage: Arc<dyn StateStorage> = match storage_config.storage_type {
common::configuration::StateStorageType::Memory => {
info!("Initialized conversation state storage: Memory");
Arc::new(MemoryConversationalStorage::new())
}
common::configuration::StateStorageType::Postgres => {
let connection_string = storage_config
.connection_string
.as_ref()
.expect("connection_string is required for postgres state_storage");
debug!("Postgres connection string (full): {}", connection_string);
info!("Initializing conversation state storage: Postgres");
Arc::new(
PostgreSQLConversationStorage::new(connection_string.clone())
.await
.expect("Failed to initialize Postgres state storage"),
)
}
debug!("Postgres connection string (full): {}", connection_string);
info!("Initializing conversation state storage: Postgres");
Arc::new(
PostgreSQLConversationStorage::new(connection_string.clone())
.await
.expect("Failed to initialize Postgres state storage"),
)
}
};
Some(storage)
} else {
info!("No state_storage configured - conversation state management disabled");
None
};
Some(storage)
} else {
info!("No state_storage configured - conversation state management disabled");
None
};
loop {
let (stream, _) = listener.accept().await?;
@ -208,12 +209,22 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
}
}
match (req.method(), path) {
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH) => {
let fully_qualified_url =
format!("{}{}", llm_provider_url, path);
llm_chat(req, router_service, fully_qualified_url, model_aliases, llm_providers, trace_collector, state_storage)
.with_context(parent_cx)
.await
(
&Method::POST,
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH,
) => {
let fully_qualified_url = format!("{}{}", llm_provider_url, path);
llm_chat(
req,
router_service,
fully_qualified_url,
model_aliases,
llm_providers,
trace_collector,
state_storage,
)
.with_context(parent_cx)
.await
}
(&Method::POST, "/function_calling") => {
let fully_qualified_url =

View file

@ -2,7 +2,7 @@ use std::collections::HashMap;
use common::configuration::{AgentUsagePreference, OrchestrationPreference};
use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role};
use serde::{Deserialize, Serialize, ser::Serialize as SerializeTrait};
use serde::{ser::Serialize as SerializeTrait, Deserialize, Serialize};
use tracing::{debug, warn};
use super::orchestrator_model::{OrchestratorModel, OrchestratorModelError};
@ -144,7 +144,7 @@ impl OrchestratorModelV1 {
// Format routes: each route as JSON on its own line with standard spacing
let agent_orchestration_json_str = agent_orchestration_values
.iter()
.map(|pref| to_spaced_json(pref))
.map(to_spaced_json)
.collect::<Vec<String>>()
.join("\n");
let agent_orchestration_to_model_map: HashMap<String, String> = agent_orchestrations
@ -238,24 +238,26 @@ impl OrchestratorModel for OrchestratorModelV1 {
let selected_conversation_list = selected_messages_list_reversed
.iter()
.rev()
.map(|message| {
Message {
role: message.role.clone(),
content: MessageContent::Text(message.content.to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
}
.map(|message| Message {
role: message.role.clone(),
content: MessageContent::Text(message.content.to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
})
.collect::<Vec<Message>>();
// Generate the orchestrator request message based on the usage preferences.
// If preferences are passed in request then we use them;
// Otherwise, we use the default orchestration modelpreferences.
let orchestrator_message = match convert_to_orchestrator_preferences(usage_preferences_from_request) {
Some(prefs) => generate_orchestrator_message(&prefs, &selected_conversation_list),
None => generate_orchestrator_message(&self.agent_orchestration_json_str, &selected_conversation_list),
};
let orchestrator_message =
match convert_to_orchestrator_preferences(usage_preferences_from_request) {
Some(prefs) => generate_orchestrator_message(&prefs, &selected_conversation_list),
None => generate_orchestrator_message(
&self.agent_orchestration_json_str,
&selected_conversation_list,
),
};
ChatCompletionsRequest {
model: self.orchestration_model.clone(),
@ -280,7 +282,8 @@ impl OrchestratorModel for OrchestratorModelV1 {
return Ok(None);
}
let orchestrator_resp_fixed = fix_json_response(content);
let orchestrator_response: AgentOrchestratorResponse = serde_json::from_str(orchestrator_resp_fixed.as_str())?;
let orchestrator_response: AgentOrchestratorResponse =
serde_json::from_str(orchestrator_resp_fixed.as_str())?;
let selected_routes = orchestrator_response.route.unwrap_or_default();
@ -320,7 +323,11 @@ impl OrchestratorModel for OrchestratorModelV1 {
} else {
// If no usage preferences are passed in request then use the default orchestration model preferences
for selected_route in valid_routes {
if let Some(model) = self.agent_orchestration_to_model_map.get(&selected_route).cloned() {
if let Some(model) = self
.agent_orchestration_to_model_map
.get(&selected_route)
.cloned()
{
result.push((selected_route, model));
} else {
warn!(
@ -375,7 +382,7 @@ fn convert_to_orchestrator_preferences(
// Format routes: each route as JSON on its own line with standard spacing
let routes_str = orchestration_preferences
.iter()
.map(|pref| to_spaced_json(pref))
.map(to_spaced_json)
.collect::<Vec<String>>()
.join("\n");
@ -425,7 +432,10 @@ mod tests {
// CRITICAL: Test that colons inside string values are NOT modified
let with_colon = serde_json::json!({"name": "foo:bar", "url": "http://example.com"});
let result = to_spaced_json(&with_colon);
assert_eq!(result, r#"{"name": "foo:bar", "url": "http://example.com"}"#);
assert_eq!(
result,
r#"{"name": "foo:bar", "url": "http://example.com"}"#
);
// Test empty object and array
let empty_obj = serde_json::json!({});
@ -446,7 +456,8 @@ mod tests {
});
let result = to_spaced_json(&complex);
// Verify URLs with colons are preserved correctly
assert!(result.contains(r#""urls": ["https://api.example.com:8080/path", "file:///local/path"]"#));
assert!(result
.contains(r#""urls": ["https://api.example.com:8080/path", "file:///local/path"]"#));
// Verify spacing format
assert!(result.contains(r#""type": "object""#));
assert!(result.contains(r#""properties": {}"#));
@ -497,10 +508,16 @@ If no routes are needed, return an empty list for `route`.
]
}
"#;
let agent_orchestrations =
serde_json::from_str::<HashMap<String, Vec<OrchestrationPreference>>>(orchestrations_str).unwrap();
let agent_orchestrations = serde_json::from_str::<
HashMap<String, Vec<OrchestrationPreference>>,
>(orchestrations_str)
.unwrap();
let orchestration_model = "test-model".to_string();
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), usize::MAX);
let orchestrator = OrchestratorModelV1::new(
agent_orchestrations,
orchestration_model.clone(),
usize::MAX,
);
let conversation_str = r#"
[
@ -568,7 +585,11 @@ If no routes are needed, return an empty list for `route`.
// Empty orchestrations map - not used when usage_preferences are provided
let agent_orchestrations: HashMap<String, Vec<OrchestrationPreference>> = HashMap::new();
let orchestration_model = "test-model".to_string();
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), usize::MAX);
let orchestrator = OrchestratorModelV1::new(
agent_orchestrations,
orchestration_model.clone(),
usize::MAX,
);
let conversation_str = r#"
[
@ -640,10 +661,13 @@ If no routes are needed, return an empty list for `route`.
]
}
"#;
let agent_orchestrations =
serde_json::from_str::<HashMap<String, Vec<OrchestrationPreference>>>(orchestrations_str).unwrap();
let agent_orchestrations = serde_json::from_str::<
HashMap<String, Vec<OrchestrationPreference>>,
>(orchestrations_str)
.unwrap();
let orchestration_model = "test-model".to_string();
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), 235);
let orchestrator =
OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), 235);
let conversation_str = r#"
[
@ -709,11 +733,14 @@ If no routes are needed, return an empty list for `route`.
]
}
"#;
let agent_orchestrations =
serde_json::from_str::<HashMap<String, Vec<OrchestrationPreference>>>(orchestrations_str).unwrap();
let agent_orchestrations = serde_json::from_str::<
HashMap<String, Vec<OrchestrationPreference>>,
>(orchestrations_str)
.unwrap();
let orchestration_model = "test-model".to_string();
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), 200);
let orchestrator =
OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), 200);
let conversation_str = r#"
[
@ -787,10 +814,13 @@ If no routes are needed, return an empty list for `route`.
]
}
"#;
let agent_orchestrations =
serde_json::from_str::<HashMap<String, Vec<OrchestrationPreference>>>(orchestrations_str).unwrap();
let agent_orchestrations = serde_json::from_str::<
HashMap<String, Vec<OrchestrationPreference>>,
>(orchestrations_str)
.unwrap();
let orchestration_model = "test-model".to_string();
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), 230);
let orchestrator =
OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), 230);
let conversation_str = r#"
[
@ -871,10 +901,16 @@ If no routes are needed, return an empty list for `route`.
]
}
"#;
let agent_orchestrations =
serde_json::from_str::<HashMap<String, Vec<OrchestrationPreference>>>(orchestrations_str).unwrap();
let agent_orchestrations = serde_json::from_str::<
HashMap<String, Vec<OrchestrationPreference>>,
>(orchestrations_str)
.unwrap();
let orchestration_model = "test-model".to_string();
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), usize::MAX);
let orchestrator = OrchestratorModelV1::new(
agent_orchestrations,
orchestration_model.clone(),
usize::MAX,
);
let conversation_str = r#"
[
@ -957,10 +993,16 @@ If no routes are needed, return an empty list for `route`.
]
}
"#;
let agent_orchestrations =
serde_json::from_str::<HashMap<String, Vec<OrchestrationPreference>>>(orchestrations_str).unwrap();
let agent_orchestrations = serde_json::from_str::<
HashMap<String, Vec<OrchestrationPreference>>,
>(orchestrations_str)
.unwrap();
let orchestration_model = "test-model".to_string();
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), usize::MAX);
let orchestrator = OrchestratorModelV1::new(
agent_orchestrations,
orchestration_model.clone(),
usize::MAX,
);
let conversation_str = r#"
[
@ -1034,10 +1076,13 @@ If no routes are needed, return an empty list for `route`.
]
}
"#;
let agent_orchestrations =
serde_json::from_str::<HashMap<String, Vec<OrchestrationPreference>>>(orchestrations_str).unwrap();
let agent_orchestrations = serde_json::from_str::<
HashMap<String, Vec<OrchestrationPreference>>,
>(orchestrations_str)
.unwrap();
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, "test-model".to_string(), 2000);
let orchestrator =
OrchestratorModelV1::new(agent_orchestrations, "test-model".to_string(), 2000);
// Case 1: Valid JSON with single route in array
let input = r#"{"route": ["Image generation"]}"#;

View file

@ -34,10 +34,7 @@ pub enum OrchestrationError {
pub type Result<T> = std::result::Result<T, OrchestrationError>;
impl OrchestratorService {
pub fn new(
orchestrator_url: String,
orchestration_model_name: String,
) -> Self {
pub fn new(orchestrator_url: String, orchestration_model_name: String) -> Self {
// Empty agent orchestrations - will be provided via usage_preferences in requests
let agent_orchestrations: HashMap<String, Vec<OrchestrationPreference>> = HashMap::new();

View file

@ -85,13 +85,19 @@ impl StateStorage for MemoryConversationalStorage {
#[cfg(test)]
mod tests {
use super::*;
use hermesllm::apis::openai_responses::{InputItem, InputMessage, MessageRole, InputContent, MessageContent};
use hermesllm::apis::openai_responses::{
InputContent, InputItem, InputMessage, MessageContent, MessageRole,
};
fn create_test_state(response_id: &str, num_messages: usize) -> OpenAIConversationState {
let mut input_items = Vec::new();
for i in 0..num_messages {
input_items.push(InputItem::Message(InputMessage {
role: if i % 2 == 0 { MessageRole::User } else { MessageRole::Assistant },
role: if i % 2 == 0 {
MessageRole::User
} else {
MessageRole::Assistant
},
content: MessageContent::Items(vec![InputContent::InputText {
text: format!("Message {}", i),
}]),
@ -252,7 +258,9 @@ mod tests {
let merged = storage.merge(&prev_state, current_input);
// Verify order: prev messages first, then current
let InputItem::Message(msg) = &merged[0] else { panic!("Expected Message") };
let InputItem::Message(msg) = &merged[0] else {
panic!("Expected Message")
};
match &msg.content {
MessageContent::Items(items) => match &items[0] {
InputContent::InputText { text } => assert_eq!(text, "Message 0"),
@ -261,7 +269,9 @@ mod tests {
_ => panic!("Expected MessageContent::Items"),
}
let InputItem::Message(msg) = &merged[2] else { panic!("Expected Message") };
let InputItem::Message(msg) = &merged[2] else {
panic!("Expected Message")
};
match &msg.content {
MessageContent::Items(items) => match &items[0] {
InputContent::InputText { text } => assert_eq!(text, "Message 2"),
@ -404,7 +414,8 @@ mod tests {
let current_input = vec![InputItem::Message(InputMessage {
role: MessageRole::User,
content: MessageContent::Items(vec![InputContent::InputText {
text: "Function result: {\"temperature\": 72, \"condition\": \"sunny\"}".to_string(),
text: "Function result: {\"temperature\": 72, \"condition\": \"sunny\"}"
.to_string(),
}]),
})];
@ -415,7 +426,9 @@ mod tests {
assert_eq!(merged.len(), 3);
// Verify the order and content
let InputItem::Message(msg1) = &merged[0] else { panic!("Expected Message") };
let InputItem::Message(msg1) = &merged[0] else {
panic!("Expected Message")
};
assert!(matches!(msg1.role, MessageRole::User));
match &msg1.content {
MessageContent::Items(items) => match &items[0] {
@ -427,7 +440,9 @@ mod tests {
_ => panic!("Expected MessageContent::Items"),
}
let InputItem::Message(msg2) = &merged[1] else { panic!("Expected Message") };
let InputItem::Message(msg2) = &merged[1] else {
panic!("Expected Message")
};
assert!(matches!(msg2.role, MessageRole::Assistant));
match &msg2.content {
MessageContent::Items(items) => match &items[0] {
@ -439,7 +454,9 @@ mod tests {
_ => panic!("Expected MessageContent::Items"),
}
let InputItem::Message(msg3) = &merged[2] else { panic!("Expected Message") };
let InputItem::Message(msg3) = &merged[2] else {
panic!("Expected Message")
};
assert!(matches!(msg3.role, MessageRole::User));
match &msg3.content {
MessageContent::Items(items) => match &items[0] {
@ -508,11 +525,15 @@ mod tests {
assert_eq!(merged.len(), 5);
// Verify first item is original user message
let InputItem::Message(first) = &merged[0] else { panic!("Expected Message") };
let InputItem::Message(first) = &merged[0] else {
panic!("Expected Message")
};
assert!(matches!(first.role, MessageRole::User));
// Verify last two are function outputs
let InputItem::Message(second_last) = &merged[3] else { panic!("Expected Message") };
let InputItem::Message(second_last) = &merged[3] else {
panic!("Expected Message")
};
assert!(matches!(second_last.role, MessageRole::User));
match &second_last.content {
MessageContent::Items(items) => match &items[0] {
@ -522,7 +543,9 @@ mod tests {
_ => panic!("Expected MessageContent::Items"),
}
let InputItem::Message(last) = &merged[4] else { panic!("Expected Message") };
let InputItem::Message(last) = &merged[4] else {
panic!("Expected Message")
};
assert!(matches!(last.role, MessageRole::User));
match &last.content {
MessageContent::Items(items) => match &items[0] {
@ -590,7 +613,9 @@ mod tests {
assert_eq!(merged.len(), 5);
// Verify the entire conversation flow is preserved
let InputItem::Message(first) = &merged[0] else { panic!("Expected Message") };
let InputItem::Message(first) = &merged[0] else {
panic!("Expected Message")
};
match &first.content {
MessageContent::Items(items) => match &items[0] {
InputContent::InputText { text } => assert!(text.contains("What's the weather")),
@ -599,7 +624,9 @@ mod tests {
_ => panic!("Expected MessageContent::Items"),
}
let InputItem::Message(last) = &merged[4] else { panic!("Expected Message") };
let InputItem::Message(last) = &merged[4] else {
panic!("Expected Message")
};
match &last.content {
MessageContent::Items(items) => match &items[0] {
InputContent::InputText { text } => assert!(text.contains("umbrella")),

View file

@ -1,14 +1,16 @@
use async_trait::async_trait;
use hermesllm::apis::openai_responses::{InputItem, InputMessage, InputContent, MessageContent, MessageRole, InputParam};
use hermesllm::apis::openai_responses::{
InputContent, InputItem, InputMessage, InputParam, MessageContent, MessageRole,
};
use serde::{Deserialize, Serialize};
use std::error::Error;
use std::fmt;
use std::sync::Arc;
use tracing::{debug};
use tracing::debug;
pub mod memory;
pub mod response_state_processor;
pub mod postgresql;
pub mod response_state_processor;
/// Represents the conversational state for a v1/responses request
/// Contains the complete input/output history that can be restored
@ -47,7 +49,9 @@ pub enum StateStorageError {
impl fmt::Display for StateStorageError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
StateStorageError::NotFound(id) => write!(f, "Conversation state not found for response_id: {}", id),
StateStorageError::NotFound(id) => {
write!(f, "Conversation state not found for response_id: {}", id)
}
StateStorageError::StorageError(msg) => write!(f, "Storage error: {}", msg),
StateStorageError::SerializationError(msg) => write!(f, "Serialization error: {}", msg),
}
@ -96,8 +100,6 @@ pub trait StateStorage: Send + Sync {
}
}
/// Storage backend type enum
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StorageBackend {
@ -106,7 +108,7 @@ pub enum StorageBackend {
}
impl StorageBackend {
pub fn from_str(s: &str) -> Option<Self> {
pub fn parse_backend(s: &str) -> Option<Self> {
match s.to_lowercase().as_str() {
"memory" => Some(StorageBackend::Memory),
"supabase" => Some(StorageBackend::Supabase),
@ -139,7 +141,6 @@ pub async fn retrieve_and_combine_input(
previous_response_id: &str,
current_input: Vec<InputItem>,
) -> Result<Vec<InputItem>, StateStorageError> {
// First get the previous state
let prev_state = storage.get(previous_response_id).await?;
let combined_input = storage.merge(&prev_state, current_input);

View file

@ -149,13 +149,12 @@ impl StateStorage for PostgreSQLConversationStorage {
let provider: String = row.get("provider");
// Deserialize input_items from JSONB
let input_items =
serde_json::from_value(input_items_json).map_err(|e| {
StateStorageError::StorageError(format!(
"Failed to deserialize input_items: {}",
e
))
})?;
let input_items = serde_json::from_value(input_items_json).map_err(|e| {
StateStorageError::StorageError(format!(
"Failed to deserialize input_items: {}",
e
))
})?;
Ok(OpenAIConversationState {
response_id,
@ -230,7 +229,9 @@ Run that SQL file against your database before using this storage backend.
#[cfg(test)]
mod tests {
use super::*;
use hermesllm::apis::openai_responses::{InputContent, InputItem, InputMessage, MessageContent, MessageRole};
use hermesllm::apis::openai_responses::{
InputContent, InputItem, InputMessage, MessageContent, MessageRole,
};
fn create_test_state(response_id: &str) -> OpenAIConversationState {
OpenAIConversationState {
@ -320,7 +321,10 @@ mod tests {
let result = storage.get("nonexistent_id").await;
assert!(result.is_err());
assert!(matches!(result.unwrap_err(), StateStorageError::NotFound(_)));
assert!(matches!(
result.unwrap_err(),
StateStorageError::NotFound(_)
));
}
#[tokio::test]
@ -372,7 +376,10 @@ mod tests {
let result = storage.delete("nonexistent_id").await;
assert!(result.is_err());
assert!(matches!(result.unwrap_err(), StateStorageError::NotFound(_)));
assert!(matches!(
result.unwrap_err(),
StateStorageError::NotFound(_)
));
}
#[tokio::test]
@ -423,9 +430,13 @@ mod tests {
println!("✅ Data written to Supabase!");
println!("Check your Supabase dashboard:");
println!(" SELECT * FROM conversation_states WHERE response_id = 'manual_test_verification';");
println!(
" SELECT * FROM conversation_states WHERE response_id = 'manual_test_verification';"
);
println!("\nTo cleanup, run:");
println!(" DELETE FROM conversation_states WHERE response_id = 'manual_test_verification';");
println!(
" DELETE FROM conversation_states WHERE response_id = 'manual_test_verification';"
);
// DON'T cleanup - leave it for manual verification
}

View file

@ -1,13 +1,11 @@
use bytes::Bytes;
use flate2::read::GzDecoder;
use hermesllm::apis::openai_responses::{
InputItem, OutputItem, ResponsesAPIStreamEvent,
};
use hermesllm::apis::openai_responses::{InputItem, OutputItem, ResponsesAPIStreamEvent};
use hermesllm::apis::streaming_shapes::sse::SseStreamIter;
use hermesllm::transforms::response::output_to_input::outputs_to_inputs;
use std::io::Read;
use std::sync::Arc;
use tracing::{info, debug, warn};
use tracing::{debug, info, warn};
use crate::handlers::utils::StreamProcessor;
use crate::state::{OpenAIConversationState, StateStorage};
@ -53,6 +51,7 @@ pub struct ResponsesStateProcessor<P: StreamProcessor> {
}
impl<P: StreamProcessor> ResponsesStateProcessor<P> {
#[allow(clippy::too_many_arguments)]
pub fn new(
inner: P,
storage: Arc<dyn StateStorage>,
@ -139,20 +138,19 @@ impl<P: StreamProcessor> ResponsesStateProcessor<P> {
for event in sse_iter {
// Only process data lines (skip event-only lines)
if let Some(data_str) = &event.data {
// Try to parse as ResponsesAPIStreamEvent
if let Ok(stream_event) = serde_json::from_str::<ResponsesAPIStreamEvent>(data_str) {
// Check if this is a ResponseCompleted event
if let ResponsesAPIStreamEvent::ResponseCompleted { response, .. } = stream_event {
info!(
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Captured streaming response.completed: response_id={}, output_items={}",
self.request_id,
response.id,
response.output.len()
);
self.response_id = Some(response.id.clone());
self.output_items = Some(response.output.clone());
return; // Found what we need, exit early
}
// Try to parse as ResponsesAPIStreamEvent and check if it's a ResponseCompleted event
if let Ok(ResponsesAPIStreamEvent::ResponseCompleted { response, .. }) =
serde_json::from_str::<ResponsesAPIStreamEvent>(data_str)
{
info!(
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Captured streaming response.completed: response_id={}, output_items={}",
self.request_id,
response.id,
response.output.len()
);
self.response_id = Some(response.id.clone());
self.output_items = Some(response.output.clone());
return; // Found what we need, exit early
}
}
}
@ -172,7 +170,9 @@ impl<P: StreamProcessor> ResponsesStateProcessor<P> {
let decompressed = self.decompress_buffer();
// Parse complete JSON response
match serde_json::from_slice::<hermesllm::apis::openai_responses::ResponsesAPIResponse>(&decompressed) {
match serde_json::from_slice::<hermesllm::apis::openai_responses::ResponsesAPIResponse>(
&decompressed,
) {
Ok(response) => {
info!(
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Captured non-streaming response: response_id={}, output_items={}",

View file

@ -2,11 +2,9 @@
///
/// This module defines standard attribute keys following OTEL semantic conventions.
/// See: https://opentelemetry.io/docs/specs/semconv/
// =============================================================================
// Span Attributes - HTTP
// =============================================================================
/// Semantic conventions for HTTP-related span attributes
pub mod http {
/// HTTP request method

View file

@ -1,3 +1,3 @@
mod constants;
pub use constants::{OperationNameBuilder, operation_component, http, llm, error, routing};
pub use constants::{error, http, llm, operation_component, routing, OperationNameBuilder};