mirror of
https://github.com/katanemo/plano.git
synced 2026-05-06 22:32:42 +02:00
cargo clippy (#660)
This commit is contained in:
parent
c75e7606f9
commit
ca95ffb63d
62 changed files with 1864 additions and 1187 deletions
|
|
@ -3,7 +3,7 @@ use std::time::{Instant, SystemTime};
|
|||
|
||||
use bytes::Bytes;
|
||||
use common::consts::TRACE_PARENT_HEADER;
|
||||
use common::traces::{SpanBuilder, SpanKind, parse_traceparent, generate_random_span_id};
|
||||
use common::traces::{generate_random_span_id, parse_traceparent, SpanBuilder, SpanKind};
|
||||
use hermesllm::apis::OpenAIMessage;
|
||||
use hermesllm::clients::SupportedAPIsFromClient;
|
||||
use hermesllm::providers::request::ProviderRequest;
|
||||
|
|
@ -18,7 +18,7 @@ use super::agent_selector::{AgentSelectionError, AgentSelector};
|
|||
use super::pipeline_processor::{PipelineError, PipelineProcessor};
|
||||
use super::response_handler::ResponseHandler;
|
||||
use crate::router::plano_orchestrator::OrchestratorService;
|
||||
use crate::tracing::{OperationNameBuilder, operation_component, http};
|
||||
use crate::tracing::{http, operation_component, OperationNameBuilder};
|
||||
|
||||
/// Main errors for agent chat completions
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
|
|
@ -61,7 +61,6 @@ pub async fn agent_chat(
|
|||
body,
|
||||
}) = &err
|
||||
{
|
||||
|
||||
warn!(
|
||||
"Client error from agent '{}' (HTTP {}): {}",
|
||||
agent, status, body
|
||||
|
|
@ -77,8 +76,8 @@ pub async fn agent_chat(
|
|||
|
||||
let json_string = error_json.to_string();
|
||||
let mut response = Response::new(ResponseHandler::create_full_body(json_string));
|
||||
*response.status_mut() = hyper::StatusCode::from_u16(*status)
|
||||
.unwrap_or(hyper::StatusCode::BAD_REQUEST);
|
||||
*response.status_mut() =
|
||||
hyper::StatusCode::from_u16(*status).unwrap_or(hyper::StatusCode::BAD_REQUEST);
|
||||
response.headers_mut().insert(
|
||||
hyper::header::CONTENT_TYPE,
|
||||
"application/json".parse().unwrap(),
|
||||
|
|
@ -234,8 +233,18 @@ async fn handle_agent_chat(
|
|||
.with_attribute(http::TARGET, "/agents/select")
|
||||
.with_attribute("selection.listener", listener.name.clone())
|
||||
.with_attribute("selection.agent_count", selected_agents.len().to_string())
|
||||
.with_attribute("selection.agents", selected_agents.iter().map(|a| a.id.as_str()).collect::<Vec<_>>().join(","))
|
||||
.with_attribute("duration_ms", format!("{:.2}", selection_elapsed.as_secs_f64() * 1000.0));
|
||||
.with_attribute(
|
||||
"selection.agents",
|
||||
selected_agents
|
||||
.iter()
|
||||
.map(|a| a.id.as_str())
|
||||
.collect::<Vec<_>>()
|
||||
.join(","),
|
||||
)
|
||||
.with_attribute(
|
||||
"duration_ms",
|
||||
format!("{:.2}", selection_elapsed.as_secs_f64() * 1000.0),
|
||||
);
|
||||
|
||||
if !trace_id.is_empty() {
|
||||
selection_span_builder = selection_span_builder.with_trace_id(trace_id.clone());
|
||||
|
|
@ -318,8 +327,14 @@ async fn handle_agent_chat(
|
|||
.with_attribute(http::METHOD, "POST")
|
||||
.with_attribute(http::TARGET, full_path)
|
||||
.with_attribute("agent.name", agent_name.clone())
|
||||
.with_attribute("agent.sequence", format!("{}/{}", agent_index + 1, agent_count))
|
||||
.with_attribute("duration_ms", format!("{:.2}", agent_elapsed.as_secs_f64() * 1000.0));
|
||||
.with_attribute(
|
||||
"agent.sequence",
|
||||
format!("{}/{}", agent_index + 1, agent_count),
|
||||
)
|
||||
.with_attribute(
|
||||
"duration_ms",
|
||||
format!("{:.2}", agent_elapsed.as_secs_f64() * 1000.0),
|
||||
);
|
||||
|
||||
if !trace_id.is_empty() {
|
||||
span_builder = span_builder.with_trace_id(trace_id.clone());
|
||||
|
|
@ -333,7 +348,10 @@ async fn handle_agent_chat(
|
|||
|
||||
// If this is the last agent, return the streaming response
|
||||
if is_last_agent {
|
||||
info!("Completed agent chain, returning response from last agent: {}", agent_name);
|
||||
info!(
|
||||
"Completed agent chain, returning response from last agent: {}",
|
||||
agent_name
|
||||
);
|
||||
return response_handler
|
||||
.create_streaming_response(llm_response)
|
||||
.await
|
||||
|
|
@ -341,7 +359,10 @@ async fn handle_agent_chat(
|
|||
}
|
||||
|
||||
// For intermediate agents, collect the full response and pass to next agent
|
||||
debug!("Collecting response from intermediate agent: {}", agent_name);
|
||||
debug!(
|
||||
"Collecting response from intermediate agent: {}",
|
||||
agent_name
|
||||
);
|
||||
let response_text = response_handler.collect_full_response(llm_response).await?;
|
||||
|
||||
info!(
|
||||
|
|
@ -364,7 +385,6 @@ async fn handle_agent_chat(
|
|||
});
|
||||
|
||||
current_messages.push(last_message);
|
||||
|
||||
}
|
||||
|
||||
// This should never be reached since we return in the last agent iteration
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ use std::collections::HashMap;
|
|||
use std::sync::Arc;
|
||||
|
||||
use common::configuration::{
|
||||
Agent, AgentFilterChain, Listener, AgentUsagePreference, OrchestrationPreference,
|
||||
Agent, AgentFilterChain, AgentUsagePreference, Listener, OrchestrationPreference,
|
||||
};
|
||||
use hermesllm::apis::openai::Message;
|
||||
use tracing::{debug, warn};
|
||||
|
|
|
|||
|
|
@ -1,20 +1,18 @@
|
|||
use bytes::Bytes;
|
||||
use eventsource_stream::Eventsource;
|
||||
use futures::StreamExt;
|
||||
use hermesllm::apis::openai::{
|
||||
ChatCompletionsRequest, ChatCompletionsResponse, Choice, FinishReason, FunctionCall, Message,
|
||||
MessageContent, ResponseMessage, Role, Tool, ToolCall, Usage,
|
||||
};
|
||||
use http_body_util::{combinators::BoxBody, BodyExt, Full};
|
||||
use hyper::body::Incoming;
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use std::collections::HashMap;
|
||||
use thiserror::Error;
|
||||
use tracing::{info, error};
|
||||
use futures::StreamExt;
|
||||
use bytes::Bytes;
|
||||
use http_body_util::{combinators::BoxBody, BodyExt, Full};
|
||||
use hyper::body::Incoming;
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use eventsource_stream::Eventsource;
|
||||
|
||||
|
||||
use tracing::{error, info};
|
||||
|
||||
// ============================================================================
|
||||
// CONSTANTS FOR HALLUCINATION DETECTION
|
||||
|
|
@ -273,17 +271,14 @@ impl ArchFunctionHandler {
|
|||
let mut stack: Vec<char> = Vec::new();
|
||||
let mut fixed_str = String::new();
|
||||
|
||||
let matching_bracket: HashMap<char, char> =
|
||||
[(')', '('), ('}', '{'), (']', '[')]
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
let opening_bracket: HashMap<char, char> = matching_bracket
|
||||
let matching_bracket: HashMap<char, char> = [(')', '('), ('}', '{'), (']', '[')]
|
||||
.iter()
|
||||
.map(|(k, v)| (*v, *k))
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
let opening_bracket: HashMap<char, char> =
|
||||
matching_bracket.iter().map(|(k, v)| (*v, *k)).collect();
|
||||
|
||||
for ch in json_str.chars() {
|
||||
if ch == '{' || ch == '[' || ch == '(' {
|
||||
stack.push(ch);
|
||||
|
|
@ -332,12 +327,18 @@ impl ArchFunctionHandler {
|
|||
// Remove markdown code blocks
|
||||
let mut content = content.trim().to_string();
|
||||
if content.starts_with("```") && content.ends_with("```") {
|
||||
content = content.trim_start_matches("```").trim_end_matches("```").to_string();
|
||||
content = content
|
||||
.trim_start_matches("```")
|
||||
.trim_end_matches("```")
|
||||
.to_string();
|
||||
if content.starts_with("json") {
|
||||
content = content.trim_start_matches("json").to_string();
|
||||
}
|
||||
// Trim again after removing code blocks to eliminate internal whitespace
|
||||
content = content.trim_start_matches(r"\n").trim_end_matches(r"\n").to_string();
|
||||
content = content
|
||||
.trim_start_matches(r"\n")
|
||||
.trim_end_matches(r"\n")
|
||||
.to_string();
|
||||
content = content.trim().to_string();
|
||||
// Unescape the quotes: \" -> "
|
||||
// The model sometimes returns escaped JSON inside markdown blocks
|
||||
|
|
@ -453,12 +454,12 @@ impl ArchFunctionHandler {
|
|||
/// Helper method to check if a value matches the expected type
|
||||
fn check_value_type(&self, value: &Value, target_type: &str) -> bool {
|
||||
match target_type {
|
||||
"int" | "integer" => value.is_i64() || value.is_u64(),
|
||||
"int" | "integer" => value.is_i64() || value.is_u64(),
|
||||
"float" | "number" => value.is_f64() || value.is_i64() || value.is_u64(),
|
||||
"bool" | "boolean" => value.is_boolean(),
|
||||
"str" | "string" => value.is_string(),
|
||||
"list" | "array" => value.is_array(),
|
||||
"dict" | "object" => value.is_object(),
|
||||
"bool" | "boolean" => value.is_boolean(),
|
||||
"str" | "string" => value.is_string(),
|
||||
"list" | "array" => value.is_array(),
|
||||
"dict" | "object" => value.is_object(),
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
|
|
@ -505,15 +506,19 @@ impl ArchFunctionHandler {
|
|||
let func_name = &tool_call.function.name;
|
||||
|
||||
// Parse arguments as JSON
|
||||
let func_args: HashMap<String, Value> = match serde_json::from_str(&tool_call.function.arguments) {
|
||||
Ok(args) => args,
|
||||
Err(e) => {
|
||||
verification.is_valid = false;
|
||||
verification.invalid_tool_call = Some(tool_call.clone());
|
||||
verification.error_message = format!("Failed to parse arguments for function '{}': {}", func_name, e);
|
||||
break;
|
||||
}
|
||||
};
|
||||
let func_args: HashMap<String, Value> =
|
||||
match serde_json::from_str(&tool_call.function.arguments) {
|
||||
Ok(args) => args,
|
||||
Err(e) => {
|
||||
verification.is_valid = false;
|
||||
verification.invalid_tool_call = Some(tool_call.clone());
|
||||
verification.error_message = format!(
|
||||
"Failed to parse arguments for function '{}': {}",
|
||||
func_name, e
|
||||
);
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
// Check if function is available
|
||||
if let Some(function_params) = functions.get(func_name) {
|
||||
|
|
@ -541,14 +546,23 @@ impl ArchFunctionHandler {
|
|||
if let Some(properties_obj) = properties.as_object() {
|
||||
for (param_name, param_value) in &func_args {
|
||||
if let Some(param_schema) = properties_obj.get(param_name) {
|
||||
if let Some(target_type) = param_schema.get("type").and_then(|v| v.as_str()) {
|
||||
if self.config.support_data_types.contains(&target_type.to_string()) {
|
||||
if let Some(target_type) =
|
||||
param_schema.get("type").and_then(|v| v.as_str())
|
||||
{
|
||||
if self
|
||||
.config
|
||||
.support_data_types
|
||||
.contains(&target_type.to_string())
|
||||
{
|
||||
// Validate data type using helper method
|
||||
match self.validate_or_convert_parameter(param_value, target_type) {
|
||||
match self
|
||||
.validate_or_convert_parameter(param_value, target_type)
|
||||
{
|
||||
Ok(is_valid) => {
|
||||
if !is_valid {
|
||||
verification.is_valid = false;
|
||||
verification.invalid_tool_call = Some(tool_call.clone());
|
||||
verification.invalid_tool_call =
|
||||
Some(tool_call.clone());
|
||||
verification.error_message = format!(
|
||||
"Parameter `{}` is expected to have the data type `{}`, got incompatible type.",
|
||||
param_name, target_type
|
||||
|
|
@ -558,7 +572,8 @@ impl ArchFunctionHandler {
|
|||
}
|
||||
Err(_) => {
|
||||
verification.is_valid = false;
|
||||
verification.invalid_tool_call = Some(tool_call.clone());
|
||||
verification.invalid_tool_call =
|
||||
Some(tool_call.clone());
|
||||
verification.error_message = format!(
|
||||
"Parameter `{}` is expected to have the data type `{}`, got incompatible type.",
|
||||
param_name, target_type
|
||||
|
|
@ -569,7 +584,10 @@ impl ArchFunctionHandler {
|
|||
} else {
|
||||
verification.is_valid = false;
|
||||
verification.invalid_tool_call = Some(tool_call.clone());
|
||||
verification.error_message = format!("Data type `{}` is not supported.", target_type);
|
||||
verification.error_message = format!(
|
||||
"Data type `{}` is not supported.",
|
||||
target_type
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
@ -598,11 +616,8 @@ impl ArchFunctionHandler {
|
|||
/// Formats the system prompt with tools
|
||||
pub fn format_system_prompt(&self, tools: &[Tool]) -> Result<String> {
|
||||
let tools_str = self.convert_tools(tools)?;
|
||||
let system_prompt = self
|
||||
.config
|
||||
.task_prompt
|
||||
.replace("{tools}", &tools_str)
|
||||
+ &self.config.format_prompt;
|
||||
let system_prompt =
|
||||
self.config.task_prompt.replace("{tools}", &tools_str) + &self.config.format_prompt;
|
||||
|
||||
Ok(system_prompt)
|
||||
}
|
||||
|
|
@ -665,15 +680,22 @@ impl ArchFunctionHandler {
|
|||
|
||||
// Strip markdown code blocks
|
||||
if tool_call_msg.starts_with("```") && tool_call_msg.ends_with("```") {
|
||||
tool_call_msg = tool_call_msg.trim_start_matches("```").trim_end_matches("```").trim().to_string();
|
||||
tool_call_msg = tool_call_msg
|
||||
.trim_start_matches("```")
|
||||
.trim_end_matches("```")
|
||||
.trim()
|
||||
.to_string();
|
||||
if tool_call_msg.starts_with("json") {
|
||||
tool_call_msg = tool_call_msg.trim_start_matches("json").trim().to_string();
|
||||
tool_call_msg =
|
||||
tool_call_msg.trim_start_matches("json").trim().to_string();
|
||||
}
|
||||
}
|
||||
|
||||
// Extract function name
|
||||
if let Ok(parsed) = serde_json::from_str::<Value>(&tool_call_msg) {
|
||||
if let Some(tool_calls_arr) = parsed.get("tool_calls").and_then(|v| v.as_array()) {
|
||||
if let Some(tool_calls_arr) =
|
||||
parsed.get("tool_calls").and_then(|v| v.as_array())
|
||||
{
|
||||
if let Some(first_tool_call) = tool_calls_arr.first() {
|
||||
let func_name = first_tool_call
|
||||
.get("name")
|
||||
|
|
@ -685,8 +707,10 @@ impl ArchFunctionHandler {
|
|||
"result": content,
|
||||
});
|
||||
|
||||
content = format!("<tool_response>\n{}\n</tool_response>",
|
||||
serde_json::to_string(&tool_response)?);
|
||||
content = format!(
|
||||
"<tool_response>\n{}\n</tool_response>",
|
||||
serde_json::to_string(&tool_response)?
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -717,7 +741,7 @@ impl ArchFunctionHandler {
|
|||
if let Some(instruction) = extra_instruction {
|
||||
if let Some(last) = processed_messages.last_mut() {
|
||||
if let MessageContent::Text(content) = &mut last.content {
|
||||
content.push_str("\n");
|
||||
content.push('\n');
|
||||
content.push_str(instruction);
|
||||
}
|
||||
}
|
||||
|
|
@ -750,13 +774,11 @@ impl ArchFunctionHandler {
|
|||
for i in (conversation_idx..messages.len()).rev() {
|
||||
if let MessageContent::Text(content) = &messages[i].content {
|
||||
num_tokens += content.len() / 4;
|
||||
if num_tokens >= max_tokens {
|
||||
if messages[i].role == Role::User {
|
||||
// Set message_idx to current position and break
|
||||
// This matches Python's behavior where message_idx is set before break
|
||||
message_idx = i;
|
||||
break;
|
||||
}
|
||||
if num_tokens >= max_tokens && messages[i].role == Role::User {
|
||||
// Set message_idx to current position and break
|
||||
// This matches Python's behavior where message_idx is set before break
|
||||
message_idx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Only update message_idx if we haven't hit the token limit yet
|
||||
|
|
@ -789,7 +811,11 @@ impl ArchFunctionHandler {
|
|||
}
|
||||
|
||||
/// Helper to create a request with VLLM-specific parameters
|
||||
fn create_request_with_extra_body(&self, messages: Vec<Message>, stream: bool) -> ChatCompletionsRequest {
|
||||
fn create_request_with_extra_body(
|
||||
&self,
|
||||
messages: Vec<Message>,
|
||||
stream: bool,
|
||||
) -> ChatCompletionsRequest {
|
||||
ChatCompletionsRequest {
|
||||
model: self.model_name.clone(),
|
||||
messages,
|
||||
|
|
@ -813,24 +839,38 @@ impl ArchFunctionHandler {
|
|||
}
|
||||
|
||||
/// Makes a streaming request and returns the SSE event stream
|
||||
async fn make_streaming_request(&self, request: ChatCompletionsRequest) -> Result<std::pin::Pin<Box<dyn futures::Stream<Item = std::result::Result<Value, String>> + Send>>> {
|
||||
let request_body = serde_json::to_string(&request)
|
||||
.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Failed to serialize request: {}", e)))?;
|
||||
async fn make_streaming_request(
|
||||
&self,
|
||||
request: ChatCompletionsRequest,
|
||||
) -> Result<
|
||||
std::pin::Pin<Box<dyn futures::Stream<Item = std::result::Result<Value, String>> + Send>>,
|
||||
> {
|
||||
let request_body = serde_json::to_string(&request).map_err(|e| {
|
||||
FunctionCallingError::InvalidModelResponse(format!(
|
||||
"Failed to serialize request: {}",
|
||||
e
|
||||
))
|
||||
})?;
|
||||
|
||||
let response = self.http_client
|
||||
let response = self
|
||||
.http_client
|
||||
.post(&self.endpoint_url)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(request_body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| FunctionCallingError::HttpError(e))?;
|
||||
.map_err(FunctionCallingError::HttpError)?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(FunctionCallingError::InvalidModelResponse(
|
||||
format!("HTTP error {}: {}", status, error_text)
|
||||
));
|
||||
let error_text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(FunctionCallingError::InvalidModelResponse(format!(
|
||||
"HTTP error {}: {}",
|
||||
status, error_text
|
||||
)));
|
||||
}
|
||||
|
||||
// Parse SSE stream
|
||||
|
|
@ -856,38 +896,51 @@ impl ArchFunctionHandler {
|
|||
}
|
||||
|
||||
/// Makes a non-streaming request and returns the response
|
||||
async fn make_non_streaming_request(&self, request: ChatCompletionsRequest) -> Result<ChatCompletionsResponse> {
|
||||
let request_body = serde_json::to_string(&request)
|
||||
.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Failed to serialize request: {}", e)))?;
|
||||
async fn make_non_streaming_request(
|
||||
&self,
|
||||
request: ChatCompletionsRequest,
|
||||
) -> Result<ChatCompletionsResponse> {
|
||||
let request_body = serde_json::to_string(&request).map_err(|e| {
|
||||
FunctionCallingError::InvalidModelResponse(format!(
|
||||
"Failed to serialize request: {}",
|
||||
e
|
||||
))
|
||||
})?;
|
||||
|
||||
let response = self.http_client
|
||||
let response = self
|
||||
.http_client
|
||||
.post(&self.endpoint_url)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(request_body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| FunctionCallingError::HttpError(e))?;
|
||||
.map_err(FunctionCallingError::HttpError)?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(FunctionCallingError::InvalidModelResponse(
|
||||
format!("HTTP error {}: {}", status, error_text)
|
||||
));
|
||||
let error_text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(FunctionCallingError::InvalidModelResponse(format!(
|
||||
"HTTP error {}: {}",
|
||||
status, error_text
|
||||
)));
|
||||
}
|
||||
|
||||
let response_text = response.text().await
|
||||
.map_err(|e| FunctionCallingError::HttpError(e))?;
|
||||
let response_text = response
|
||||
.text()
|
||||
.await
|
||||
.map_err(FunctionCallingError::HttpError)?;
|
||||
|
||||
serde_json::from_str(&response_text)
|
||||
.map_err(|e| FunctionCallingError::JsonParseError(e))
|
||||
serde_json::from_str(&response_text).map_err(FunctionCallingError::JsonParseError)
|
||||
}
|
||||
|
||||
pub async fn function_calling_chat(
|
||||
&self,
|
||||
request: ChatCompletionsRequest,
|
||||
) -> Result<ChatCompletionsResponse> {
|
||||
use tracing::{info, error};
|
||||
use tracing::{error, info};
|
||||
|
||||
info!("[Arch-Function] - ChatCompletion");
|
||||
|
||||
|
|
@ -899,10 +952,14 @@ impl ArchFunctionHandler {
|
|||
request.metadata.as_ref(),
|
||||
)?;
|
||||
|
||||
info!("[request to arch-fc]: model: {}, messages count: {}",
|
||||
self.model_name, messages.len());
|
||||
info!(
|
||||
"[request to arch-fc]: model: {}, messages count: {}",
|
||||
self.model_name,
|
||||
messages.len()
|
||||
);
|
||||
|
||||
let use_agent_orchestrator = request.metadata
|
||||
let use_agent_orchestrator = request
|
||||
.metadata
|
||||
.as_ref()
|
||||
.and_then(|m| m.get("use_agent_orchestrator"))
|
||||
.and_then(|v| v.as_bool())
|
||||
|
|
@ -918,89 +975,95 @@ impl ArchFunctionHandler {
|
|||
|
||||
if use_agent_orchestrator {
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?;
|
||||
let chunk = chunk_result.map_err(FunctionCallingError::InvalidModelResponse)?;
|
||||
// Extract content from JSON response
|
||||
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
|
||||
if let Some(choice) = choices.first() {
|
||||
if let Some(content) = choice.get("delta")
|
||||
if let Some(content) = choice
|
||||
.get("delta")
|
||||
.and_then(|d| d.get("content"))
|
||||
.and_then(|c| c.as_str()) {
|
||||
.and_then(|c| c.as_str())
|
||||
{
|
||||
model_response.push_str(content);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
info!("[Agent Orchestrator]: response received");
|
||||
} else {
|
||||
if let Some(tools) = request.tools.as_ref() {
|
||||
let mut hallucination_state = HallucinationState::new(tools);
|
||||
let mut has_tool_calls = None;
|
||||
let mut has_hallucination = false;
|
||||
} else if let Some(tools) = request.tools.as_ref() {
|
||||
let mut hallucination_state = HallucinationState::new(tools);
|
||||
let mut has_tool_calls = None;
|
||||
let mut has_hallucination = false;
|
||||
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?;
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
let chunk = chunk_result.map_err(FunctionCallingError::InvalidModelResponse)?;
|
||||
|
||||
// Extract content and logprobs from JSON response
|
||||
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
|
||||
if let Some(choice) = choices.first() {
|
||||
if let Some(content) = choice.get("delta")
|
||||
.and_then(|d| d.get("content"))
|
||||
.and_then(|c| c.as_str()) {
|
||||
// Extract content and logprobs from JSON response
|
||||
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
|
||||
if let Some(choice) = choices.first() {
|
||||
if let Some(content) = choice
|
||||
.get("delta")
|
||||
.and_then(|d| d.get("content"))
|
||||
.and_then(|c| c.as_str())
|
||||
{
|
||||
// Extract logprobs
|
||||
let logprobs: Vec<f64> = choice
|
||||
.get("logprobs")
|
||||
.and_then(|lp| lp.get("content"))
|
||||
.and_then(|c| c.as_array())
|
||||
.and_then(|arr| arr.first())
|
||||
.and_then(|token| token.get("top_logprobs"))
|
||||
.and_then(|tlp| tlp.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.get("logprob").and_then(|lp| lp.as_f64()))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
// Extract logprobs
|
||||
let logprobs: Vec<f64> = choice.get("logprobs")
|
||||
.and_then(|lp| lp.get("content"))
|
||||
.and_then(|c| c.as_array())
|
||||
.and_then(|arr| arr.first())
|
||||
.and_then(|token| token.get("top_logprobs"))
|
||||
.and_then(|tlp| tlp.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.get("logprob").and_then(|lp| lp.as_f64()))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
if hallucination_state
|
||||
.append_and_check_token_hallucination(content.to_string(), logprobs)
|
||||
{
|
||||
has_hallucination = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if hallucination_state.append_and_check_token_hallucination(content.to_string(), logprobs) {
|
||||
has_hallucination = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if hallucination_state.tokens.len() > 5 && has_tool_calls.is_none() {
|
||||
let collected_content = hallucination_state.tokens.join("");
|
||||
has_tool_calls = Some(collected_content.contains("tool_calls"));
|
||||
}
|
||||
if hallucination_state.tokens.len() > 5 && has_tool_calls.is_none() {
|
||||
let collected_content = hallucination_state.tokens.join("");
|
||||
has_tool_calls = Some(collected_content.contains("tool_calls"));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if has_tool_calls == Some(true) && has_hallucination {
|
||||
info!("[Hallucination]: {}", hallucination_state.error_message);
|
||||
if has_tool_calls == Some(true) && has_hallucination {
|
||||
info!("[Hallucination]: {}", hallucination_state.error_message);
|
||||
|
||||
let clarify_messages = self.prefill_message(messages.clone(), &self.clarify_prefix);
|
||||
let clarify_request = self.create_request_with_extra_body(clarify_messages, false);
|
||||
let clarify_messages = self.prefill_message(messages.clone(), &self.clarify_prefix);
|
||||
let clarify_request = self.create_request_with_extra_body(clarify_messages, false);
|
||||
|
||||
let retry_response = self.make_non_streaming_request(clarify_request).await?;
|
||||
let retry_response = self.make_non_streaming_request(clarify_request).await?;
|
||||
|
||||
if let Some(choice) = retry_response.choices.first() {
|
||||
if let Some(content) = &choice.message.content {
|
||||
model_response = content.clone();
|
||||
}
|
||||
if let Some(choice) = retry_response.choices.first() {
|
||||
if let Some(content) = &choice.message.content {
|
||||
model_response = content.clone();
|
||||
}
|
||||
} else {
|
||||
model_response = hallucination_state.tokens.join("");
|
||||
}
|
||||
} else {
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?;
|
||||
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
|
||||
if let Some(choice) = choices.first() {
|
||||
if let Some(content) = choice.get("delta")
|
||||
.and_then(|d| d.get("content"))
|
||||
.and_then(|c| c.as_str()) {
|
||||
model_response.push_str(content);
|
||||
}
|
||||
model_response = hallucination_state.tokens.join("");
|
||||
}
|
||||
} else {
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
let chunk = chunk_result.map_err(FunctionCallingError::InvalidModelResponse)?;
|
||||
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
|
||||
if let Some(choice) = choices.first() {
|
||||
if let Some(content) = choice
|
||||
.get("delta")
|
||||
.and_then(|d| d.get("content"))
|
||||
.and_then(|c| c.as_str())
|
||||
{
|
||||
model_response.push_str(content);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1009,10 +1072,17 @@ impl ArchFunctionHandler {
|
|||
|
||||
let response_dict = self.parse_model_response(&model_response);
|
||||
|
||||
info!("[arch-fc]: raw model response: {}", response_dict.raw_response);
|
||||
info!(
|
||||
"[arch-fc]: raw model response: {}",
|
||||
response_dict.raw_response
|
||||
);
|
||||
|
||||
// General model response (no intent matched - should route to default target)
|
||||
let model_message = if response_dict.response.as_ref().map_or(false, |s| !s.is_empty()) {
|
||||
let model_message = if response_dict
|
||||
.response
|
||||
.as_ref()
|
||||
.is_some_and(|s| !s.is_empty())
|
||||
{
|
||||
// When arch-fc returns a "response" field, it means no intent was matched
|
||||
// Return empty content and empty tool_calls so prompt_gateway routes to default target
|
||||
ResponseMessage {
|
||||
|
|
@ -1053,8 +1123,11 @@ impl ArchFunctionHandler {
|
|||
let verification = self.verify_tool_calls(tools, &response_dict.tool_calls);
|
||||
|
||||
if verification.is_valid {
|
||||
info!("[Tool calls]: {:?}",
|
||||
response_dict.tool_calls.iter()
|
||||
info!(
|
||||
"[Tool calls]: {:?}",
|
||||
response_dict
|
||||
.tool_calls
|
||||
.iter()
|
||||
.map(|tc| &tc.function)
|
||||
.collect::<Vec<_>>()
|
||||
);
|
||||
|
|
@ -1092,8 +1165,11 @@ impl ArchFunctionHandler {
|
|||
}
|
||||
}
|
||||
} else {
|
||||
info!("[Tool calls]: {:?}",
|
||||
response_dict.tool_calls.iter()
|
||||
info!(
|
||||
"[Tool calls]: {:?}",
|
||||
response_dict
|
||||
.tool_calls
|
||||
.iter()
|
||||
.map(|tc| &tc.function)
|
||||
.collect::<Vec<_>>()
|
||||
);
|
||||
|
|
@ -1108,7 +1184,10 @@ impl ArchFunctionHandler {
|
|||
}
|
||||
}
|
||||
} else {
|
||||
error!("Invalid tool calls in response: {}", response_dict.error_message);
|
||||
error!(
|
||||
"Invalid tool calls in response: {}",
|
||||
response_dict.error_message
|
||||
);
|
||||
ResponseMessage {
|
||||
role: Role::Assistant,
|
||||
content: Some(String::new()),
|
||||
|
|
@ -1243,7 +1322,6 @@ pub async fn function_calling_chat_handler(
|
|||
req: Request<Incoming>,
|
||||
llm_provider_url: String,
|
||||
) -> std::result::Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
|
||||
use hermesllm::apis::openai::ChatCompletionsRequest;
|
||||
let whole_body = req.collect().await?.to_bytes();
|
||||
|
||||
|
|
@ -1255,10 +1333,13 @@ pub async fn function_calling_chat_handler(
|
|||
let mut response = Response::new(full(
|
||||
serde_json::json!({
|
||||
"error": format!("Invalid request body: {}", e)
|
||||
}).to_string()
|
||||
})
|
||||
.to_string(),
|
||||
));
|
||||
*response.status_mut() = StatusCode::BAD_REQUEST;
|
||||
response.headers_mut().insert("Content-Type", "application/json".parse().unwrap());
|
||||
response
|
||||
.headers_mut()
|
||||
.insert("Content-Type", "application/json".parse().unwrap());
|
||||
return Ok(response);
|
||||
}
|
||||
};
|
||||
|
|
@ -1271,24 +1352,31 @@ pub async fn function_calling_chat_handler(
|
|||
// Parse as ChatCompletionsRequest
|
||||
let chat_request: ChatCompletionsRequest = match serde_json::from_value(body_json) {
|
||||
Ok(req) => {
|
||||
info!("[request body]: {}", serde_json::to_string(&req).unwrap_or_default());
|
||||
info!(
|
||||
"[request body]: {}",
|
||||
serde_json::to_string(&req).unwrap_or_default()
|
||||
);
|
||||
req
|
||||
},
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to parse request body: {}", e);
|
||||
let mut response = Response::new(full(
|
||||
serde_json::json!({
|
||||
"error": format!("Invalid request body: {}", e)
|
||||
}).to_string()
|
||||
})
|
||||
.to_string(),
|
||||
));
|
||||
*response.status_mut() = StatusCode::BAD_REQUEST;
|
||||
response.headers_mut().insert("Content-Type", "application/json".parse().unwrap());
|
||||
response
|
||||
.headers_mut()
|
||||
.insert("Content-Type", "application/json".parse().unwrap());
|
||||
return Ok(response);
|
||||
}
|
||||
};
|
||||
|
||||
// Determine which handler to use based on metadata
|
||||
let use_agent_orchestrator = chat_request.metadata
|
||||
let use_agent_orchestrator = chat_request
|
||||
.metadata
|
||||
.as_ref()
|
||||
.and_then(|m| m.get("use_agent_orchestrator"))
|
||||
.and_then(|v| v.as_bool())
|
||||
|
|
@ -1309,7 +1397,10 @@ pub async fn function_calling_chat_handler(
|
|||
ARCH_FUNCTION_MODEL_NAME.to_string(),
|
||||
llm_provider_url.clone(),
|
||||
);
|
||||
handler.function_handler.function_calling_chat(chat_request).await
|
||||
handler
|
||||
.function_handler
|
||||
.function_calling_chat(chat_request)
|
||||
.await
|
||||
} else {
|
||||
let handler = ArchFunctionHandler::new(
|
||||
ARCH_FUNCTION_MODEL_NAME.to_string(),
|
||||
|
|
@ -1328,7 +1419,9 @@ pub async fn function_calling_chat_handler(
|
|||
|
||||
let mut response = Response::new(full(response_json));
|
||||
*response.status_mut() = StatusCode::OK;
|
||||
response.headers_mut().insert("Content-Type", "application/json".parse().unwrap());
|
||||
response
|
||||
.headers_mut()
|
||||
.insert("Content-Type", "application/json".parse().unwrap());
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
|
@ -1341,13 +1434,14 @@ pub async fn function_calling_chat_handler(
|
|||
|
||||
let mut response = Response::new(full(error_response.to_string()));
|
||||
*response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
response.headers_mut().insert("Content-Type", "application/json".parse().unwrap());
|
||||
response
|
||||
.headers_mut()
|
||||
.insert("Content-Type", "application/json".parse().unwrap());
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// ============================================================================
|
||||
// TESTS
|
||||
// ============================================================================
|
||||
|
|
@ -1370,10 +1464,13 @@ mod tests {
|
|||
assert!(config.task_prompt.contains("</tools>\\n\\n"));
|
||||
|
||||
// Format prompt should contain literal escaped newlines and proper JSON examples
|
||||
assert!(config.format_prompt.contains("\\n\\nBased on your analysis"));
|
||||
assert!(config.format_prompt.contains(r#"{\"response\": \"Your response text here\"}"#));
|
||||
assert!(config
|
||||
.format_prompt
|
||||
.contains("\\n\\nBased on your analysis"));
|
||||
assert!(config
|
||||
.format_prompt
|
||||
.contains(r#"{\"response\": \"Your response text here\"}"#));
|
||||
assert!(config.format_prompt.contains(r#"{\"tool_calls\": [{"#));
|
||||
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -1384,7 +1481,11 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_fix_json_string_valid() {
|
||||
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string());
|
||||
let handler = ArchFunctionHandler::new(
|
||||
"test-model".to_string(),
|
||||
ArchFunctionConfig::default(),
|
||||
"http://localhost:8000".to_string(),
|
||||
);
|
||||
let json_str = r#"{"name": "test", "value": 123}"#;
|
||||
let result = handler.fix_json_string(json_str);
|
||||
assert!(result.is_ok());
|
||||
|
|
@ -1392,7 +1493,11 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_fix_json_string_missing_bracket() {
|
||||
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string());
|
||||
let handler = ArchFunctionHandler::new(
|
||||
"test-model".to_string(),
|
||||
ArchFunctionConfig::default(),
|
||||
"http://localhost:8000".to_string(),
|
||||
);
|
||||
let json_str = r#"{"name": "test", "value": 123"#;
|
||||
let result = handler.fix_json_string(json_str);
|
||||
assert!(result.is_ok());
|
||||
|
|
@ -1402,8 +1507,13 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_parse_model_response_with_tool_calls() {
|
||||
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string());
|
||||
let content = r#"{"tool_calls": [{"name": "get_weather", "arguments": {"location": "NYC"}}]}"#;
|
||||
let handler = ArchFunctionHandler::new(
|
||||
"test-model".to_string(),
|
||||
ArchFunctionConfig::default(),
|
||||
"http://localhost:8000".to_string(),
|
||||
);
|
||||
let content =
|
||||
r#"{"tool_calls": [{"name": "get_weather", "arguments": {"location": "NYC"}}]}"#;
|
||||
let result = handler.parse_model_response(content);
|
||||
|
||||
assert!(result.is_valid);
|
||||
|
|
@ -1413,8 +1523,13 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_parse_model_response_with_clarification() {
|
||||
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string());
|
||||
let content = r#"{"required_functions": ["get_weather"], "clarification": "What location?"}"#;
|
||||
let handler = ArchFunctionHandler::new(
|
||||
"test-model".to_string(),
|
||||
ArchFunctionConfig::default(),
|
||||
"http://localhost:8000".to_string(),
|
||||
);
|
||||
let content =
|
||||
r#"{"required_functions": ["get_weather"], "clarification": "What location?"}"#;
|
||||
let result = handler.parse_model_response(content);
|
||||
|
||||
assert!(result.is_valid);
|
||||
|
|
@ -1424,7 +1539,11 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_convert_data_type_int_to_float() {
|
||||
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string());
|
||||
let handler = ArchFunctionHandler::new(
|
||||
"test-model".to_string(),
|
||||
ArchFunctionConfig::default(),
|
||||
"http://localhost:8000".to_string(),
|
||||
);
|
||||
let value = json!(42);
|
||||
let result = handler.convert_data_type(&value, "float");
|
||||
assert!(result.is_ok());
|
||||
|
|
@ -1504,13 +1623,12 @@ pub fn check_threshold(
|
|||
}
|
||||
|
||||
/// Checks if a parameter is required in the function description
|
||||
pub fn is_parameter_required(
|
||||
function_description: &Value,
|
||||
parameter_name: &str,
|
||||
) -> bool {
|
||||
pub fn is_parameter_required(function_description: &Value, parameter_name: &str) -> bool {
|
||||
if let Some(required) = function_description.get("required") {
|
||||
if let Some(required_arr) = required.as_array() {
|
||||
return required_arr.iter().any(|v| v.as_str() == Some(parameter_name));
|
||||
return required_arr
|
||||
.iter()
|
||||
.any(|v| v.as_str() == Some(parameter_name));
|
||||
}
|
||||
}
|
||||
false
|
||||
|
|
@ -1559,12 +1677,7 @@ impl HallucinationState {
|
|||
pub fn new(functions: &[Tool]) -> Self {
|
||||
let function_properties: HashMap<String, Value> = functions
|
||||
.iter()
|
||||
.map(|tool| {
|
||||
(
|
||||
tool.function.name.clone(),
|
||||
tool.function.parameters.clone(),
|
||||
)
|
||||
})
|
||||
.map(|tool| (tool.function.name.clone(), tool.function.parameters.clone()))
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
|
|
@ -1620,7 +1733,10 @@ impl HallucinationState {
|
|||
|
||||
// Function name extraction logic
|
||||
if self.state.as_deref() == Some("function_name") {
|
||||
if !FUNC_NAME_END_TOKEN.iter().any(|&t| self.tokens.last().map_or(false, |tok| tok == t)) {
|
||||
if !FUNC_NAME_END_TOKEN
|
||||
.iter()
|
||||
.any(|&t| self.tokens.last().is_some_and(|tok| tok == t))
|
||||
{
|
||||
self.mask.push(MaskToken::FunctionName);
|
||||
} else {
|
||||
self.state = None;
|
||||
|
|
@ -1629,34 +1745,51 @@ impl HallucinationState {
|
|||
}
|
||||
|
||||
// Check for function name start
|
||||
if FUNC_NAME_START_PATTERN.iter().any(|&p| content.ends_with(p)) {
|
||||
if FUNC_NAME_START_PATTERN
|
||||
.iter()
|
||||
.any(|&p| content.ends_with(p))
|
||||
{
|
||||
self.state = Some("function_name".to_string());
|
||||
}
|
||||
|
||||
// Parameter name extraction logic
|
||||
if self.state.as_deref() == Some("parameter_name")
|
||||
&& !PARAMETER_NAME_END_TOKENS.iter().any(|&t| content.ends_with(t)) {
|
||||
&& !PARAMETER_NAME_END_TOKENS
|
||||
.iter()
|
||||
.any(|&t| content.ends_with(t))
|
||||
{
|
||||
self.mask.push(MaskToken::ParameterName);
|
||||
} else if self.state.as_deref() == Some("parameter_name")
|
||||
&& PARAMETER_NAME_END_TOKENS.iter().any(|&t| content.ends_with(t)) {
|
||||
&& PARAMETER_NAME_END_TOKENS
|
||||
.iter()
|
||||
.any(|&t| content.ends_with(t))
|
||||
{
|
||||
self.state = None;
|
||||
self.parameter_name_done = true;
|
||||
self.get_parameter_name();
|
||||
} else if self.parameter_name_done
|
||||
&& !self.open_bracket
|
||||
&& PARAMETER_NAME_START_PATTERN.iter().any(|&p| content.ends_with(p)) {
|
||||
&& PARAMETER_NAME_START_PATTERN
|
||||
.iter()
|
||||
.any(|&p| content.ends_with(p))
|
||||
{
|
||||
self.state = Some("parameter_name".to_string());
|
||||
}
|
||||
|
||||
// First parameter value start
|
||||
if FIRST_PARAM_NAME_START_PATTERN.iter().any(|&p| content.ends_with(p)) {
|
||||
if FIRST_PARAM_NAME_START_PATTERN
|
||||
.iter()
|
||||
.any(|&p| content.ends_with(p))
|
||||
{
|
||||
self.state = Some("parameter_name".to_string());
|
||||
}
|
||||
|
||||
// Parameter value extraction logic
|
||||
if self.state.as_deref() == Some("parameter_value")
|
||||
&& !PARAMETER_VALUE_END_TOKEN.iter().any(|&t| content.ends_with(t)) {
|
||||
|
||||
&& !PARAMETER_VALUE_END_TOKEN
|
||||
.iter()
|
||||
.any(|&t| content.ends_with(t))
|
||||
{
|
||||
// Check for brackets
|
||||
if let Some(last_token) = self.tokens.last() {
|
||||
let open_brackets: Vec<char> = last_token
|
||||
|
|
@ -1694,8 +1827,11 @@ impl HallucinationState {
|
|||
&& self.mask[self.mask.len() - 2] != MaskToken::ParameterValue
|
||||
&& !self.parameter_name.is_empty()
|
||||
{
|
||||
let last_param = self.parameter_name[self.parameter_name.len() - 1].clone();
|
||||
if let Some(func_props) = self.function_properties.get(&self.function_name) {
|
||||
let last_param =
|
||||
self.parameter_name[self.parameter_name.len() - 1].clone();
|
||||
if let Some(func_props) =
|
||||
self.function_properties.get(&self.function_name)
|
||||
{
|
||||
if is_parameter_required(func_props, &last_param)
|
||||
&& !is_parameter_property(func_props, &last_param, "enum")
|
||||
&& !self.check_parameter_name.contains_key(&last_param)
|
||||
|
|
@ -1718,10 +1854,16 @@ impl HallucinationState {
|
|||
}
|
||||
} else if self.state.as_deref() == Some("parameter_value")
|
||||
&& !self.open_bracket
|
||||
&& PARAMETER_VALUE_END_TOKEN.iter().any(|&t| content.ends_with(t)) {
|
||||
&& PARAMETER_VALUE_END_TOKEN
|
||||
.iter()
|
||||
.any(|&t| content.ends_with(t))
|
||||
{
|
||||
self.state = None;
|
||||
} else if self.parameter_name_done
|
||||
&& PARAMETER_VALUE_START_PATTERN.iter().any(|&p| content.ends_with(p)) {
|
||||
&& PARAMETER_VALUE_START_PATTERN
|
||||
.iter()
|
||||
.any(|&p| content.ends_with(p))
|
||||
{
|
||||
self.state = Some("parameter_value".to_string());
|
||||
}
|
||||
|
||||
|
|
@ -1848,18 +1990,18 @@ mod hallucination_tests {
|
|||
let handler = ArchFunctionHandler::new(
|
||||
"test-model".to_string(),
|
||||
ArchFunctionConfig::default(),
|
||||
"http://localhost:8000".to_string()
|
||||
"http://localhost:8000".to_string(),
|
||||
);
|
||||
|
||||
// Test integer types
|
||||
assert!(handler.check_value_type(&json!(42), "integer"));
|
||||
assert!(handler.check_value_type(&json!(42), "int"));
|
||||
assert!(!handler.check_value_type(&json!(3.14), "integer"));
|
||||
assert!(!handler.check_value_type(&json!(3.15), "integer"));
|
||||
|
||||
// Test number types (accepts both int and float)
|
||||
assert!(handler.check_value_type(&json!(3.14), "number"));
|
||||
assert!(handler.check_value_type(&json!(3.15), "number"));
|
||||
assert!(handler.check_value_type(&json!(42), "number"));
|
||||
assert!(handler.check_value_type(&json!(3.14), "float"));
|
||||
assert!(handler.check_value_type(&json!(3.15), "float"));
|
||||
|
||||
// Test boolean
|
||||
assert!(handler.check_value_type(&json!(true), "boolean"));
|
||||
|
|
@ -1890,12 +2032,16 @@ mod hallucination_tests {
|
|||
let handler = ArchFunctionHandler::new(
|
||||
"test-model".to_string(),
|
||||
ArchFunctionConfig::default(),
|
||||
"http://localhost:8000".to_string()
|
||||
"http://localhost:8000".to_string(),
|
||||
);
|
||||
|
||||
// Test valid type - no conversion needed
|
||||
assert!(handler.validate_or_convert_parameter(&json!(42), "integer").unwrap());
|
||||
assert!(handler.validate_or_convert_parameter(&json!("hello"), "string").unwrap());
|
||||
assert!(handler
|
||||
.validate_or_convert_parameter(&json!(42), "integer")
|
||||
.unwrap());
|
||||
assert!(handler
|
||||
.validate_or_convert_parameter(&json!("hello"), "string")
|
||||
.unwrap());
|
||||
|
||||
// Test integer to float conversion (convert_data_type supports this)
|
||||
let result = handler.validate_or_convert_parameter(&json!(42), "float");
|
||||
|
|
@ -1910,8 +2056,12 @@ mod hallucination_tests {
|
|||
assert!(!result.unwrap());
|
||||
|
||||
// Test number accepting both int and float
|
||||
assert!(handler.validate_or_convert_parameter(&json!(42), "number").unwrap());
|
||||
assert!(handler.validate_or_convert_parameter(&json!(3.14), "number").unwrap());
|
||||
assert!(handler
|
||||
.validate_or_convert_parameter(&json!(42), "number")
|
||||
.unwrap());
|
||||
assert!(handler
|
||||
.validate_or_convert_parameter(&json!(3.15), "number")
|
||||
.unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ use crate::router::plano_orchestrator::OrchestratorService;
|
|||
/// 2. PipelineProcessor - executes the agent pipeline
|
||||
/// 3. ResponseHandler - handles response streaming
|
||||
#[cfg(test)]
|
||||
mod integration_tests {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use common::configuration::{Agent, AgentFilterChain, Listener};
|
||||
|
||||
|
|
@ -62,7 +62,10 @@ mod integration_tests {
|
|||
|
||||
let agent_pipeline = AgentFilterChain {
|
||||
id: "terminal-agent".to_string(),
|
||||
filter_chain: Some(vec!["filter-agent".to_string(), "terminal-agent".to_string()]),
|
||||
filter_chain: Some(vec![
|
||||
"filter-agent".to_string(),
|
||||
"terminal-agent".to_string(),
|
||||
]),
|
||||
description: Some("Test pipeline".to_string()),
|
||||
default: Some(true),
|
||||
};
|
||||
|
|
|
|||
|
|
@ -2,48 +2,48 @@ use serde::{Deserialize, Serialize};
|
|||
use std::collections::HashMap;
|
||||
|
||||
pub const JSON_RPC_VERSION: &str = "2.0";
|
||||
pub const TOOL_CALL_METHOD : &str = "tools/call";
|
||||
pub const TOOL_CALL_METHOD: &str = "tools/call";
|
||||
pub const MCP_INITIALIZE: &str = "initialize";
|
||||
pub const MCP_INITIALIZE_NOTIFICATION: &str = "notifications/initialized";
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum JsonRpcId {
|
||||
String(String),
|
||||
Number(u64),
|
||||
String(String),
|
||||
Number(u64),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JsonRpcRequest {
|
||||
pub jsonrpc: String,
|
||||
pub id: JsonRpcId,
|
||||
pub method: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub params: Option<HashMap<String, serde_json::Value>>,
|
||||
pub jsonrpc: String,
|
||||
pub id: JsonRpcId,
|
||||
pub method: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub params: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JsonRpcNotification {
|
||||
pub jsonrpc: String,
|
||||
pub method: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub params: Option<HashMap<String, serde_json::Value>>,
|
||||
pub jsonrpc: String,
|
||||
pub method: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub params: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JsonRpcError {
|
||||
pub code: i32,
|
||||
pub message: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub data: Option<serde_json::Value>,
|
||||
pub code: i32,
|
||||
pub message: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub data: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JsonRpcResponse {
|
||||
pub jsonrpc: String,
|
||||
pub id: JsonRpcId,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub result: Option<HashMap<String, serde_json::Value>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<JsonRpcError>,
|
||||
pub jsonrpc: String,
|
||||
pub id: JsonRpcId,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub result: Option<HashMap<String, serde_json::Value>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<JsonRpcError>,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
use bytes::Bytes;
|
||||
use common::configuration::{LlmProvider, ModelAlias};
|
||||
use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER};
|
||||
use common::consts::{
|
||||
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
|
||||
};
|
||||
use common::traces::TraceCollector;
|
||||
use hermesllm::apis::openai_responses::InputParam;
|
||||
use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
||||
|
|
@ -14,13 +16,14 @@ use std::sync::Arc;
|
|||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::router::llm_router::RouterService;
|
||||
use crate::handlers::utils::{create_streaming_response, ObservableStreamProcessor, truncate_message};
|
||||
use crate::handlers::router_chat::router_chat_get_upstream_model;
|
||||
use crate::handlers::utils::{
|
||||
create_streaming_response, truncate_message, ObservableStreamProcessor,
|
||||
};
|
||||
use crate::router::llm_router::RouterService;
|
||||
use crate::state::response_state_processor::ResponsesStateProcessor;
|
||||
use crate::state::{
|
||||
StateStorage, StateStorageError,
|
||||
extract_input_items, retrieve_and_combine_input
|
||||
extract_input_items, retrieve_and_combine_input, StateStorage, StateStorageError,
|
||||
};
|
||||
use crate::tracing::operation_component;
|
||||
|
||||
|
|
@ -39,7 +42,6 @@ pub async fn llm_chat(
|
|||
trace_collector: Arc<TraceCollector>,
|
||||
state_storage: Option<Arc<dyn StateStorage>>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
|
||||
let request_path = request.uri().path().to_string();
|
||||
let request_headers = request.headers().clone();
|
||||
let request_id = request_headers
|
||||
|
|
@ -74,8 +76,14 @@ pub async fn llm_chat(
|
|||
)) {
|
||||
Ok(request) => request,
|
||||
Err(err) => {
|
||||
warn!("[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request as ProviderRequestType: {}", request_id, err);
|
||||
let err_msg = format!("[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request: {}", request_id, err);
|
||||
warn!(
|
||||
"[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request as ProviderRequestType: {}",
|
||||
request_id, err
|
||||
);
|
||||
let err_msg = format!(
|
||||
"[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request: {}",
|
||||
request_id, err
|
||||
);
|
||||
let mut bad_request = Response::new(full(err_msg));
|
||||
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
|
||||
return Ok(bad_request);
|
||||
|
|
@ -85,7 +93,10 @@ pub async fn llm_chat(
|
|||
// === v1/responses state management: Extract input items early ===
|
||||
let mut original_input_items = Vec::new();
|
||||
let client_api = SupportedAPIsFromClient::from_endpoint(request_path.as_str());
|
||||
let is_responses_api_client = matches!(client_api, Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_)));
|
||||
let is_responses_api_client = matches!(
|
||||
client_api,
|
||||
Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_))
|
||||
);
|
||||
|
||||
// Model alias resolution: update model field in client_request immediately
|
||||
// This ensures all downstream objects use the resolved model
|
||||
|
|
@ -96,20 +107,28 @@ pub async fn llm_chat(
|
|||
|
||||
// Extract tool names and user message preview for span attributes
|
||||
let tool_names = client_request.get_tool_names();
|
||||
let user_message_preview = client_request.get_recent_user_message()
|
||||
let user_message_preview = client_request
|
||||
.get_recent_user_message()
|
||||
.map(|msg| truncate_message(&msg, 50));
|
||||
|
||||
client_request.set_model(resolved_model.clone());
|
||||
if client_request.remove_metadata_key("archgw_preference_config") {
|
||||
debug!("[PLANO_REQ_ID:{}] Removed archgw_preference_config from metadata", request_id);
|
||||
debug!(
|
||||
"[PLANO_REQ_ID:{}] Removed archgw_preference_config from metadata",
|
||||
request_id
|
||||
);
|
||||
}
|
||||
|
||||
// === v1/responses state management: Determine upstream API and combine input if needed ===
|
||||
// Do this BEFORE routing since routing consumes the request
|
||||
// Only process state if state_storage is configured
|
||||
let mut should_manage_state = false;
|
||||
if is_responses_api_client && state_storage.is_some() {
|
||||
if let ProviderRequestType::ResponsesAPIRequest(ref mut responses_req) = client_request {
|
||||
if is_responses_api_client {
|
||||
if let (
|
||||
ProviderRequestType::ResponsesAPIRequest(ref mut responses_req),
|
||||
Some(ref state_store),
|
||||
) = (&mut client_request, &state_storage)
|
||||
{
|
||||
// Extract original input once
|
||||
original_input_items = extract_input_items(&responses_req.input);
|
||||
|
||||
|
|
@ -120,18 +139,22 @@ pub async fn llm_chat(
|
|||
&request_path,
|
||||
&resolved_model,
|
||||
is_streaming_request,
|
||||
).await;
|
||||
)
|
||||
.await;
|
||||
|
||||
let upstream_api = SupportedUpstreamAPIs::from_endpoint(&upstream_path);
|
||||
|
||||
// Only manage state if upstream is NOT OpenAIResponsesAPI (needs translation)
|
||||
should_manage_state = !matches!(upstream_api, Some(SupportedUpstreamAPIs::OpenAIResponsesAPI(_)));
|
||||
should_manage_state = !matches!(
|
||||
upstream_api,
|
||||
Some(SupportedUpstreamAPIs::OpenAIResponsesAPI(_))
|
||||
);
|
||||
|
||||
if should_manage_state {
|
||||
// Retrieve and combine conversation history if previous_response_id exists
|
||||
if let Some(ref prev_resp_id) = responses_req.previous_response_id {
|
||||
match retrieve_and_combine_input(
|
||||
state_storage.as_ref().unwrap().clone(),
|
||||
state_store.clone(),
|
||||
prev_resp_id,
|
||||
original_input_items, // Pass ownership instead of cloning
|
||||
)
|
||||
|
|
@ -166,7 +189,10 @@ pub async fn llm_chat(
|
|||
}
|
||||
}
|
||||
} else {
|
||||
debug!("[PLANO_REQ_ID:{}] | BRIGHT_STAFF | Upstream supports ResponsesAPI natively.", request_id);
|
||||
debug!(
|
||||
"[PLANO_REQ_ID:{}] | BRIGHT_STAFF | Upstream supports ResponsesAPI natively.",
|
||||
request_id
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -177,7 +203,7 @@ pub async fn llm_chat(
|
|||
// Determine routing using the dedicated router_chat module
|
||||
let routing_result = match router_chat_get_upstream_model(
|
||||
router_service,
|
||||
client_request, // Pass the original request - router_chat will convert it
|
||||
client_request, // Pass the original request - router_chat will convert it
|
||||
&request_headers,
|
||||
trace_collector.clone(),
|
||||
&traceparent,
|
||||
|
|
@ -257,7 +283,8 @@ pub async fn llm_chat(
|
|||
user_message_preview,
|
||||
temperature,
|
||||
&llm_providers,
|
||||
).await;
|
||||
)
|
||||
.await;
|
||||
|
||||
// Create base processor for metrics and tracing
|
||||
let base_processor = ObservableStreamProcessor::new(
|
||||
|
|
@ -269,7 +296,11 @@ pub async fn llm_chat(
|
|||
|
||||
// === v1/responses state management: Wrap with ResponsesStateProcessor ===
|
||||
// Only wrap if we need to manage state (client is ResponsesAPI AND upstream is NOT ResponsesAPI AND state_storage is configured)
|
||||
let streaming_response = if should_manage_state && !original_input_items.is_empty() && state_storage.is_some() {
|
||||
let streaming_response = if let (true, false, Some(state_store)) = (
|
||||
should_manage_state,
|
||||
original_input_items.is_empty(),
|
||||
state_storage,
|
||||
) {
|
||||
// Extract Content-Encoding header to handle decompression for state parsing
|
||||
let content_encoding = response_headers
|
||||
.get("content-encoding")
|
||||
|
|
@ -279,7 +310,7 @@ pub async fn llm_chat(
|
|||
// Wrap with state management processor to store state after response completes
|
||||
let state_processor = ResponsesStateProcessor::new(
|
||||
base_processor,
|
||||
state_storage.unwrap(),
|
||||
state_store,
|
||||
original_input_items,
|
||||
resolved_model.clone(),
|
||||
model_name.clone(),
|
||||
|
|
@ -324,6 +355,7 @@ fn resolve_model_alias(
|
|||
}
|
||||
|
||||
/// Builds the LLM span with all required and optional attributes.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn build_llm_span(
|
||||
traceparent: &str,
|
||||
request_path: &str,
|
||||
|
|
@ -337,8 +369,8 @@ async fn build_llm_span(
|
|||
temperature: Option<f32>,
|
||||
llm_providers: &Arc<RwLock<Vec<LlmProvider>>>,
|
||||
) -> common::traces::Span {
|
||||
use common::traces::{SpanBuilder, SpanKind, parse_traceparent};
|
||||
use crate::tracing::{http, llm, OperationNameBuilder};
|
||||
use common::traces::{parse_traceparent, SpanBuilder, SpanKind};
|
||||
|
||||
// Calculate the upstream path based on provider configuration
|
||||
let upstream_path = get_upstream_path(
|
||||
|
|
@ -347,13 +379,14 @@ async fn build_llm_span(
|
|||
request_path,
|
||||
resolved_model,
|
||||
is_streaming,
|
||||
).await;
|
||||
)
|
||||
.await;
|
||||
|
||||
// Build operation name showing path transformation if different
|
||||
let operation_name = if request_path != upstream_path {
|
||||
OperationNameBuilder::new()
|
||||
.with_method("POST")
|
||||
.with_path(&format!("{} >> {}", request_path, upstream_path))
|
||||
.with_path(format!("{} >> {}", request_path, upstream_path))
|
||||
.with_target(resolved_model)
|
||||
.build()
|
||||
} else {
|
||||
|
|
@ -388,7 +421,8 @@ async fn build_llm_span(
|
|||
}
|
||||
|
||||
if let Some(tools) = tool_names {
|
||||
let formatted_tools = tools.iter()
|
||||
let formatted_tools = tools
|
||||
.iter()
|
||||
.map(|name| format!("{}(...)", name))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
|
@ -436,8 +470,7 @@ async fn get_provider_info(
|
|||
|
||||
// First, try to find by model name or provider name
|
||||
let provider = providers_lock.iter().find(|p| {
|
||||
p.model.as_ref().map(|m| m == model_name).unwrap_or(false)
|
||||
|| p.name == model_name
|
||||
p.model.as_ref().map(|m| m == model_name).unwrap_or(false) || p.name == model_name
|
||||
});
|
||||
|
||||
if let Some(provider) = provider {
|
||||
|
|
@ -446,9 +479,7 @@ async fn get_provider_info(
|
|||
return (provider_id, prefix);
|
||||
}
|
||||
|
||||
let default_provider = providers_lock.iter().find(|p| {
|
||||
p.default.unwrap_or(false)
|
||||
});
|
||||
let default_provider = providers_lock.iter().find(|p| p.default.unwrap_or(false));
|
||||
|
||||
if let Some(provider) = default_provider {
|
||||
let provider_id = provider.provider_interface.to_provider_id();
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
pub mod agent_chat_completions;
|
||||
pub mod agent_selector;
|
||||
pub mod llm;
|
||||
pub mod router_chat;
|
||||
pub mod models;
|
||||
pub mod function_calling;
|
||||
pub mod jsonrpc;
|
||||
pub mod llm;
|
||||
pub mod models;
|
||||
pub mod pipeline_processor;
|
||||
pub mod response_handler;
|
||||
pub mod router_chat;
|
||||
pub mod utils;
|
||||
pub mod jsonrpc;
|
||||
|
||||
#[cfg(test)]
|
||||
mod integration_tests;
|
||||
|
|
|
|||
|
|
@ -82,6 +82,7 @@ impl PipelineProcessor {
|
|||
}
|
||||
|
||||
/// Record a span for filter execution
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn record_filter_span(
|
||||
&self,
|
||||
collector: &std::sync::Arc<common::traces::TraceCollector>,
|
||||
|
|
@ -132,6 +133,7 @@ impl PipelineProcessor {
|
|||
}
|
||||
|
||||
/// Record a span for MCP protocol interactions
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn record_agent_filter_span(
|
||||
&self,
|
||||
collector: &std::sync::Arc<common::traces::TraceCollector>,
|
||||
|
|
@ -156,12 +158,12 @@ impl PipelineProcessor {
|
|||
.build();
|
||||
|
||||
let mut span_builder = SpanBuilder::new(&operation_name)
|
||||
.with_span_id(span_id.unwrap_or_else(|| generate_random_span_id()))
|
||||
.with_span_id(span_id.unwrap_or_else(generate_random_span_id))
|
||||
.with_kind(SpanKind::Client)
|
||||
.with_start_time(start_time)
|
||||
.with_end_time(end_time)
|
||||
.with_attribute(http::METHOD, "POST")
|
||||
.with_attribute(http::TARGET, &format!("/mcp ({})", operation.to_string()))
|
||||
.with_attribute(http::TARGET, format!("/mcp ({})", operation))
|
||||
.with_attribute("mcp.operation", operation.to_string())
|
||||
.with_attribute("mcp.agent_id", agent_id.to_string())
|
||||
.with_attribute(
|
||||
|
|
@ -188,6 +190,7 @@ impl PipelineProcessor {
|
|||
}
|
||||
|
||||
/// Process the filter chain of agents (all except the terminal agent)
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn process_filter_chain(
|
||||
&mut self,
|
||||
chat_history: &[Message],
|
||||
|
|
@ -1023,7 +1026,7 @@ mod tests {
|
|||
}
|
||||
});
|
||||
|
||||
let sse_body = format!("event: message\ndata: {}\n\n", rpc_body.to_string());
|
||||
let sse_body = format!("event: message\ndata: {}\n\n", rpc_body);
|
||||
|
||||
let mut server = Server::new_async().await;
|
||||
let _m = server
|
||||
|
|
@ -1061,10 +1064,10 @@ mod tests {
|
|||
.await;
|
||||
|
||||
match result {
|
||||
Err(PipelineError::ClientError { status, body, .. }) => {
|
||||
assert_eq!(status, 400);
|
||||
assert_eq!(body, "bad tool call");
|
||||
}
|
||||
Err(PipelineError::ClientError { status, body, .. }) => {
|
||||
assert_eq!(status, 400);
|
||||
assert_eq!(body, "bad tool call");
|
||||
}
|
||||
_ => panic!("Expected client error when isError flag is set"),
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -133,9 +133,7 @@ impl ResponseHandler {
|
|||
let response_headers = llm_response.headers();
|
||||
let is_sse_streaming = response_headers
|
||||
.get(hyper::header::CONTENT_TYPE)
|
||||
.map_or(false, |v| {
|
||||
v.to_str().unwrap_or("").contains("text/event-stream")
|
||||
});
|
||||
.is_some_and(|v| v.to_str().unwrap_or("").contains("text/event-stream"));
|
||||
|
||||
let response_bytes = llm_response
|
||||
.bytes()
|
||||
|
|
@ -164,7 +162,7 @@ impl ResponseHandler {
|
|||
match transformed_event.provider_response() {
|
||||
Ok(provider_response) => {
|
||||
if let Some(content) = provider_response.content_delta() {
|
||||
accumulated_text.push_str(&content);
|
||||
accumulated_text.push_str(content);
|
||||
} else {
|
||||
info!("No content delta in provider response");
|
||||
}
|
||||
|
|
@ -174,7 +172,7 @@ impl ResponseHandler {
|
|||
}
|
||||
}
|
||||
}
|
||||
return Ok(accumulated_text);
|
||||
Ok(accumulated_text)
|
||||
} else {
|
||||
// If not SSE, treat as regular text response
|
||||
let response_text = String::from_utf8(response_bytes.to_vec()).map_err(|e| {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
use common::configuration::ModelUsagePreference;
|
||||
use common::consts::{REQUEST_ID_HEADER};
|
||||
use common::traces::{TraceCollector, SpanKind, SpanBuilder, parse_traceparent};
|
||||
use common::consts::REQUEST_ID_HEADER;
|
||||
use common::traces::{parse_traceparent, SpanBuilder, SpanKind, TraceCollector};
|
||||
use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
|
||||
use hermesllm::{ProviderRequest, ProviderRequestType};
|
||||
use hyper::StatusCode;
|
||||
|
|
@ -9,10 +9,10 @@ use std::sync::Arc;
|
|||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::router::llm_router::RouterService;
|
||||
use crate::tracing::{OperationNameBuilder, operation_component, http, routing};
|
||||
use crate::tracing::{http, operation_component, routing, OperationNameBuilder};
|
||||
|
||||
pub struct RoutingResult {
|
||||
pub model_name: String
|
||||
pub model_name: String,
|
||||
}
|
||||
|
||||
pub struct RoutingError {
|
||||
|
|
@ -24,7 +24,7 @@ impl RoutingError {
|
|||
pub fn internal_error(message: String) -> Self {
|
||||
Self {
|
||||
message,
|
||||
status_code: StatusCode::INTERNAL_SERVER_ERROR
|
||||
status_code: StatusCode::INTERNAL_SERVER_ERROR,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -52,9 +52,7 @@ pub async fn router_chat_get_upstream_model(
|
|||
// Convert to ChatCompletionsRequest for routing (regardless of input type)
|
||||
let chat_request = match ProviderRequestType::try_from((
|
||||
client_request,
|
||||
&SupportedUpstreamAPIs::OpenAIChatCompletions(
|
||||
hermesllm::apis::OpenAIApi::ChatCompletions,
|
||||
),
|
||||
&SupportedUpstreamAPIs::OpenAIChatCompletions(hermesllm::apis::OpenAIApi::ChatCompletions),
|
||||
)) {
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(req)) => req,
|
||||
Ok(
|
||||
|
|
@ -69,7 +67,10 @@ pub async fn router_chat_get_upstream_model(
|
|||
));
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("Failed to convert request to ChatCompletionsRequest: {}", err);
|
||||
warn!(
|
||||
"Failed to convert request to ChatCompletionsRequest: {}",
|
||||
err
|
||||
);
|
||||
return Err(RoutingError::internal_error(format!(
|
||||
"Failed to convert request: {}",
|
||||
err
|
||||
|
|
@ -151,9 +152,7 @@ pub async fn router_chat_get_upstream_model(
|
|||
)
|
||||
.await;
|
||||
|
||||
Ok(RoutingResult {
|
||||
model_name
|
||||
})
|
||||
Ok(RoutingResult { model_name })
|
||||
}
|
||||
None => {
|
||||
// No route determined, use default model from request
|
||||
|
|
@ -176,7 +175,7 @@ pub async fn router_chat_get_upstream_model(
|
|||
.await;
|
||||
|
||||
Ok(RoutingResult {
|
||||
model_name: default_model
|
||||
model_name: default_model,
|
||||
})
|
||||
}
|
||||
},
|
||||
|
|
@ -194,9 +193,10 @@ pub async fn router_chat_get_upstream_model(
|
|||
)
|
||||
.await;
|
||||
|
||||
Err(RoutingError::internal_error(
|
||||
format!("Failed to determine route: {}", err)
|
||||
))
|
||||
Err(RoutingError::internal_error(format!(
|
||||
"Failed to determine route: {}",
|
||||
err
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -230,7 +230,10 @@ async fn record_routing_span(
|
|||
.with_end_time(std::time::SystemTime::now())
|
||||
.with_attribute(http::METHOD, "POST")
|
||||
.with_attribute(http::TARGET, routing_api_path.to_string())
|
||||
.with_attribute(routing::ROUTE_DETERMINATION_MS, start_time.elapsed().as_millis().to_string());
|
||||
.with_attribute(
|
||||
routing::ROUTE_DETERMINATION_MS,
|
||||
start_time.elapsed().as_millis().to_string(),
|
||||
);
|
||||
|
||||
// Only set parent span ID if it exists (not a root span)
|
||||
if let Some(parent) = parent_span_id {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use bytes::Bytes;
|
||||
use common::traces::{Span, Attribute, AttributeValue, TraceCollector, Event};
|
||||
use common::traces::{Attribute, AttributeValue, Event, Span, TraceCollector};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::StreamBody;
|
||||
use hyper::body::Frame;
|
||||
|
|
@ -11,7 +11,7 @@ use tokio_stream::StreamExt;
|
|||
use tracing::warn;
|
||||
|
||||
// Import tracing constants
|
||||
use crate::tracing::{llm, error};
|
||||
use crate::tracing::{error, llm};
|
||||
|
||||
/// Trait for processing streaming chunks
|
||||
/// Implementors can inject custom logic during streaming (e.g., hallucination detection, logging)
|
||||
|
|
@ -97,7 +97,6 @@ impl StreamProcessor for ObservableStreamProcessor {
|
|||
},
|
||||
});
|
||||
|
||||
|
||||
self.span.attributes.push(Attribute {
|
||||
key: llm::DURATION_MS.to_string(),
|
||||
value: AttributeValue {
|
||||
|
|
@ -119,11 +118,9 @@ impl StreamProcessor for ObservableStreamProcessor {
|
|||
if let Ok(start_time_nanos) = self.span.start_time_unix_nano.parse::<u128>() {
|
||||
// Convert ttft from milliseconds to nanoseconds and add to start time
|
||||
let event_timestamp = start_time_nanos + (ttft * 1_000_000);
|
||||
let mut event = Event::new(llm::TIME_TO_FIRST_TOKEN_MS.to_string(), event_timestamp);
|
||||
event.add_attribute(
|
||||
llm::TIME_TO_FIRST_TOKEN_MS.to_string(),
|
||||
ttft.to_string(),
|
||||
);
|
||||
let mut event =
|
||||
Event::new(llm::TIME_TO_FIRST_TOKEN_MS.to_string(), event_timestamp);
|
||||
event.add_attribute(llm::TIME_TO_FIRST_TOKEN_MS.to_string(), ttft.to_string());
|
||||
|
||||
// Initialize events vector if needed
|
||||
if self.span.events.is_none() {
|
||||
|
|
@ -137,7 +134,8 @@ impl StreamProcessor for ObservableStreamProcessor {
|
|||
}
|
||||
|
||||
// Record the finalized span
|
||||
self.collector.record_span(&self.service_name, self.span.clone());
|
||||
self.collector
|
||||
.record_span(&self.service_name, self.span.clone());
|
||||
}
|
||||
|
||||
fn on_error(&mut self, error_msg: &str) {
|
||||
|
|
@ -173,7 +171,8 @@ impl StreamProcessor for ObservableStreamProcessor {
|
|||
});
|
||||
|
||||
// Record the error span
|
||||
self.collector.record_span(&self.service_name, self.span.clone());
|
||||
self.collector
|
||||
.record_span(&self.service_name, self.span.clone());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -4,13 +4,15 @@ use brightstaff::handlers::llm::llm_chat;
|
|||
use brightstaff::handlers::models::list_models;
|
||||
use brightstaff::router::llm_router::RouterService;
|
||||
use brightstaff::router::plano_orchestrator::OrchestratorService;
|
||||
use brightstaff::state::StateStorage;
|
||||
use brightstaff::state::postgresql::PostgreSQLConversationStorage;
|
||||
use brightstaff::state::memory::MemoryConversationalStorage;
|
||||
use brightstaff::state::postgresql::PostgreSQLConversationStorage;
|
||||
use brightstaff::state::StateStorage;
|
||||
use brightstaff::utils::tracing::init_tracer;
|
||||
use bytes::Bytes;
|
||||
use common::configuration::{Agent, Configuration};
|
||||
use common::consts::{CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH, PLANO_ORCHESTRATOR_MODEL_NAME};
|
||||
use common::consts::{
|
||||
CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH, PLANO_ORCHESTRATOR_MODEL_NAME,
|
||||
};
|
||||
use common::traces::TraceCollector;
|
||||
use http_body_util::{combinators::BoxBody, BodyExt, Empty};
|
||||
use hyper::body::Incoming;
|
||||
|
|
@ -105,7 +107,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
PLANO_ORCHESTRATOR_MODEL_NAME.to_string(),
|
||||
));
|
||||
|
||||
|
||||
let model_aliases = Arc::new(arch_config.model_aliases.clone());
|
||||
|
||||
// Initialize trace collector and start background flusher
|
||||
|
|
@ -127,33 +128,33 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
// Configurable via arch_config.yaml state_storage section
|
||||
// If not configured, state management is disabled
|
||||
// Environment variables are substituted by envsubst before config is read
|
||||
let state_storage: Option<Arc<dyn StateStorage>> = if let Some(storage_config) = &arch_config.state_storage {
|
||||
let storage: Arc<dyn StateStorage> = match storage_config.storage_type {
|
||||
common::configuration::StateStorageType::Memory => {
|
||||
info!("Initialized conversation state storage: Memory");
|
||||
Arc::new(MemoryConversationalStorage::new())
|
||||
}
|
||||
common::configuration::StateStorageType::Postgres => {
|
||||
let connection_string = storage_config
|
||||
.connection_string
|
||||
.as_ref()
|
||||
.expect("connection_string is required for postgres state_storage");
|
||||
let state_storage: Option<Arc<dyn StateStorage>> =
|
||||
if let Some(storage_config) = &arch_config.state_storage {
|
||||
let storage: Arc<dyn StateStorage> = match storage_config.storage_type {
|
||||
common::configuration::StateStorageType::Memory => {
|
||||
info!("Initialized conversation state storage: Memory");
|
||||
Arc::new(MemoryConversationalStorage::new())
|
||||
}
|
||||
common::configuration::StateStorageType::Postgres => {
|
||||
let connection_string = storage_config
|
||||
.connection_string
|
||||
.as_ref()
|
||||
.expect("connection_string is required for postgres state_storage");
|
||||
|
||||
debug!("Postgres connection string (full): {}", connection_string);
|
||||
info!("Initializing conversation state storage: Postgres");
|
||||
Arc::new(
|
||||
PostgreSQLConversationStorage::new(connection_string.clone())
|
||||
.await
|
||||
.expect("Failed to initialize Postgres state storage"),
|
||||
)
|
||||
}
|
||||
debug!("Postgres connection string (full): {}", connection_string);
|
||||
info!("Initializing conversation state storage: Postgres");
|
||||
Arc::new(
|
||||
PostgreSQLConversationStorage::new(connection_string.clone())
|
||||
.await
|
||||
.expect("Failed to initialize Postgres state storage"),
|
||||
)
|
||||
}
|
||||
};
|
||||
Some(storage)
|
||||
} else {
|
||||
info!("No state_storage configured - conversation state management disabled");
|
||||
None
|
||||
};
|
||||
Some(storage)
|
||||
} else {
|
||||
info!("No state_storage configured - conversation state management disabled");
|
||||
None
|
||||
};
|
||||
|
||||
|
||||
loop {
|
||||
let (stream, _) = listener.accept().await?;
|
||||
|
|
@ -208,12 +209,22 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
}
|
||||
}
|
||||
match (req.method(), path) {
|
||||
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH) => {
|
||||
let fully_qualified_url =
|
||||
format!("{}{}", llm_provider_url, path);
|
||||
llm_chat(req, router_service, fully_qualified_url, model_aliases, llm_providers, trace_collector, state_storage)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
(
|
||||
&Method::POST,
|
||||
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH,
|
||||
) => {
|
||||
let fully_qualified_url = format!("{}{}", llm_provider_url, path);
|
||||
llm_chat(
|
||||
req,
|
||||
router_service,
|
||||
fully_qualified_url,
|
||||
model_aliases,
|
||||
llm_providers,
|
||||
trace_collector,
|
||||
state_storage,
|
||||
)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
}
|
||||
(&Method::POST, "/function_calling") => {
|
||||
let fully_qualified_url =
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ use std::collections::HashMap;
|
|||
|
||||
use common::configuration::{AgentUsagePreference, OrchestrationPreference};
|
||||
use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role};
|
||||
use serde::{Deserialize, Serialize, ser::Serialize as SerializeTrait};
|
||||
use serde::{ser::Serialize as SerializeTrait, Deserialize, Serialize};
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use super::orchestrator_model::{OrchestratorModel, OrchestratorModelError};
|
||||
|
|
@ -144,7 +144,7 @@ impl OrchestratorModelV1 {
|
|||
// Format routes: each route as JSON on its own line with standard spacing
|
||||
let agent_orchestration_json_str = agent_orchestration_values
|
||||
.iter()
|
||||
.map(|pref| to_spaced_json(pref))
|
||||
.map(to_spaced_json)
|
||||
.collect::<Vec<String>>()
|
||||
.join("\n");
|
||||
let agent_orchestration_to_model_map: HashMap<String, String> = agent_orchestrations
|
||||
|
|
@ -238,24 +238,26 @@ impl OrchestratorModel for OrchestratorModelV1 {
|
|||
let selected_conversation_list = selected_messages_list_reversed
|
||||
.iter()
|
||||
.rev()
|
||||
.map(|message| {
|
||||
Message {
|
||||
role: message.role.clone(),
|
||||
content: MessageContent::Text(message.content.to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}
|
||||
.map(|message| Message {
|
||||
role: message.role.clone(),
|
||||
content: MessageContent::Text(message.content.to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
})
|
||||
.collect::<Vec<Message>>();
|
||||
|
||||
// Generate the orchestrator request message based on the usage preferences.
|
||||
// If preferences are passed in request then we use them;
|
||||
// Otherwise, we use the default orchestration model preferences.
|
||||
let orchestrator_message = match convert_to_orchestrator_preferences(usage_preferences_from_request) {
|
||||
Some(prefs) => generate_orchestrator_message(&prefs, &selected_conversation_list),
|
||||
None => generate_orchestrator_message(&self.agent_orchestration_json_str, &selected_conversation_list),
|
||||
};
|
||||
let orchestrator_message =
|
||||
match convert_to_orchestrator_preferences(usage_preferences_from_request) {
|
||||
Some(prefs) => generate_orchestrator_message(&prefs, &selected_conversation_list),
|
||||
None => generate_orchestrator_message(
|
||||
&self.agent_orchestration_json_str,
|
||||
&selected_conversation_list,
|
||||
),
|
||||
};
|
||||
|
||||
ChatCompletionsRequest {
|
||||
model: self.orchestration_model.clone(),
|
||||
|
|
@ -280,7 +282,8 @@ impl OrchestratorModel for OrchestratorModelV1 {
|
|||
return Ok(None);
|
||||
}
|
||||
let orchestrator_resp_fixed = fix_json_response(content);
|
||||
let orchestrator_response: AgentOrchestratorResponse = serde_json::from_str(orchestrator_resp_fixed.as_str())?;
|
||||
let orchestrator_response: AgentOrchestratorResponse =
|
||||
serde_json::from_str(orchestrator_resp_fixed.as_str())?;
|
||||
|
||||
let selected_routes = orchestrator_response.route.unwrap_or_default();
|
||||
|
||||
|
|
@ -320,7 +323,11 @@ impl OrchestratorModel for OrchestratorModelV1 {
|
|||
} else {
|
||||
// If no usage preferences are passed in request then use the default orchestration model preferences
|
||||
for selected_route in valid_routes {
|
||||
if let Some(model) = self.agent_orchestration_to_model_map.get(&selected_route).cloned() {
|
||||
if let Some(model) = self
|
||||
.agent_orchestration_to_model_map
|
||||
.get(&selected_route)
|
||||
.cloned()
|
||||
{
|
||||
result.push((selected_route, model));
|
||||
} else {
|
||||
warn!(
|
||||
|
|
@ -375,7 +382,7 @@ fn convert_to_orchestrator_preferences(
|
|||
// Format routes: each route as JSON on its own line with standard spacing
|
||||
let routes_str = orchestration_preferences
|
||||
.iter()
|
||||
.map(|pref| to_spaced_json(pref))
|
||||
.map(to_spaced_json)
|
||||
.collect::<Vec<String>>()
|
||||
.join("\n");
|
||||
|
||||
|
|
@ -425,7 +432,10 @@ mod tests {
|
|||
// CRITICAL: Test that colons inside string values are NOT modified
|
||||
let with_colon = serde_json::json!({"name": "foo:bar", "url": "http://example.com"});
|
||||
let result = to_spaced_json(&with_colon);
|
||||
assert_eq!(result, r#"{"name": "foo:bar", "url": "http://example.com"}"#);
|
||||
assert_eq!(
|
||||
result,
|
||||
r#"{"name": "foo:bar", "url": "http://example.com"}"#
|
||||
);
|
||||
|
||||
// Test empty object and array
|
||||
let empty_obj = serde_json::json!({});
|
||||
|
|
@ -446,7 +456,8 @@ mod tests {
|
|||
});
|
||||
let result = to_spaced_json(&complex);
|
||||
// Verify URLs with colons are preserved correctly
|
||||
assert!(result.contains(r#""urls": ["https://api.example.com:8080/path", "file:///local/path"]"#));
|
||||
assert!(result
|
||||
.contains(r#""urls": ["https://api.example.com:8080/path", "file:///local/path"]"#));
|
||||
// Verify spacing format
|
||||
assert!(result.contains(r#""type": "object""#));
|
||||
assert!(result.contains(r#""properties": {}"#));
|
||||
|
|
@ -497,10 +508,16 @@ If no routes are needed, return an empty list for `route`.
|
|||
]
|
||||
}
|
||||
"#;
|
||||
let agent_orchestrations =
|
||||
serde_json::from_str::<HashMap<String, Vec<OrchestrationPreference>>>(orchestrations_str).unwrap();
|
||||
let agent_orchestrations = serde_json::from_str::<
|
||||
HashMap<String, Vec<OrchestrationPreference>>,
|
||||
>(orchestrations_str)
|
||||
.unwrap();
|
||||
let orchestration_model = "test-model".to_string();
|
||||
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), usize::MAX);
|
||||
let orchestrator = OrchestratorModelV1::new(
|
||||
agent_orchestrations,
|
||||
orchestration_model.clone(),
|
||||
usize::MAX,
|
||||
);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
|
|
@ -568,7 +585,11 @@ If no routes are needed, return an empty list for `route`.
|
|||
// Empty orchestrations map - not used when usage_preferences are provided
|
||||
let agent_orchestrations: HashMap<String, Vec<OrchestrationPreference>> = HashMap::new();
|
||||
let orchestration_model = "test-model".to_string();
|
||||
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), usize::MAX);
|
||||
let orchestrator = OrchestratorModelV1::new(
|
||||
agent_orchestrations,
|
||||
orchestration_model.clone(),
|
||||
usize::MAX,
|
||||
);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
|
|
@ -640,10 +661,13 @@ If no routes are needed, return an empty list for `route`.
|
|||
]
|
||||
}
|
||||
"#;
|
||||
let agent_orchestrations =
|
||||
serde_json::from_str::<HashMap<String, Vec<OrchestrationPreference>>>(orchestrations_str).unwrap();
|
||||
let agent_orchestrations = serde_json::from_str::<
|
||||
HashMap<String, Vec<OrchestrationPreference>>,
|
||||
>(orchestrations_str)
|
||||
.unwrap();
|
||||
let orchestration_model = "test-model".to_string();
|
||||
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), 235);
|
||||
let orchestrator =
|
||||
OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), 235);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
|
|
@ -709,11 +733,14 @@ If no routes are needed, return an empty list for `route`.
|
|||
]
|
||||
}
|
||||
"#;
|
||||
let agent_orchestrations =
|
||||
serde_json::from_str::<HashMap<String, Vec<OrchestrationPreference>>>(orchestrations_str).unwrap();
|
||||
let agent_orchestrations = serde_json::from_str::<
|
||||
HashMap<String, Vec<OrchestrationPreference>>,
|
||||
>(orchestrations_str)
|
||||
.unwrap();
|
||||
|
||||
let orchestration_model = "test-model".to_string();
|
||||
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), 200);
|
||||
let orchestrator =
|
||||
OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), 200);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
|
|
@ -787,10 +814,13 @@ If no routes are needed, return an empty list for `route`.
|
|||
]
|
||||
}
|
||||
"#;
|
||||
let agent_orchestrations =
|
||||
serde_json::from_str::<HashMap<String, Vec<OrchestrationPreference>>>(orchestrations_str).unwrap();
|
||||
let agent_orchestrations = serde_json::from_str::<
|
||||
HashMap<String, Vec<OrchestrationPreference>>,
|
||||
>(orchestrations_str)
|
||||
.unwrap();
|
||||
let orchestration_model = "test-model".to_string();
|
||||
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), 230);
|
||||
let orchestrator =
|
||||
OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), 230);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
|
|
@ -871,10 +901,16 @@ If no routes are needed, return an empty list for `route`.
|
|||
]
|
||||
}
|
||||
"#;
|
||||
let agent_orchestrations =
|
||||
serde_json::from_str::<HashMap<String, Vec<OrchestrationPreference>>>(orchestrations_str).unwrap();
|
||||
let agent_orchestrations = serde_json::from_str::<
|
||||
HashMap<String, Vec<OrchestrationPreference>>,
|
||||
>(orchestrations_str)
|
||||
.unwrap();
|
||||
let orchestration_model = "test-model".to_string();
|
||||
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), usize::MAX);
|
||||
let orchestrator = OrchestratorModelV1::new(
|
||||
agent_orchestrations,
|
||||
orchestration_model.clone(),
|
||||
usize::MAX,
|
||||
);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
|
|
@ -957,10 +993,16 @@ If no routes are needed, return an empty list for `route`.
|
|||
]
|
||||
}
|
||||
"#;
|
||||
let agent_orchestrations =
|
||||
serde_json::from_str::<HashMap<String, Vec<OrchestrationPreference>>>(orchestrations_str).unwrap();
|
||||
let agent_orchestrations = serde_json::from_str::<
|
||||
HashMap<String, Vec<OrchestrationPreference>>,
|
||||
>(orchestrations_str)
|
||||
.unwrap();
|
||||
let orchestration_model = "test-model".to_string();
|
||||
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), usize::MAX);
|
||||
let orchestrator = OrchestratorModelV1::new(
|
||||
agent_orchestrations,
|
||||
orchestration_model.clone(),
|
||||
usize::MAX,
|
||||
);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
|
|
@ -1034,10 +1076,13 @@ If no routes are needed, return an empty list for `route`.
|
|||
]
|
||||
}
|
||||
"#;
|
||||
let agent_orchestrations =
|
||||
serde_json::from_str::<HashMap<String, Vec<OrchestrationPreference>>>(orchestrations_str).unwrap();
|
||||
let agent_orchestrations = serde_json::from_str::<
|
||||
HashMap<String, Vec<OrchestrationPreference>>,
|
||||
>(orchestrations_str)
|
||||
.unwrap();
|
||||
|
||||
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, "test-model".to_string(), 2000);
|
||||
let orchestrator =
|
||||
OrchestratorModelV1::new(agent_orchestrations, "test-model".to_string(), 2000);
|
||||
|
||||
// Case 1: Valid JSON with single route in array
|
||||
let input = r#"{"route": ["Image generation"]}"#;
|
||||
|
|
|
|||
|
|
@ -34,10 +34,7 @@ pub enum OrchestrationError {
|
|||
pub type Result<T> = std::result::Result<T, OrchestrationError>;
|
||||
|
||||
impl OrchestratorService {
|
||||
pub fn new(
|
||||
orchestrator_url: String,
|
||||
orchestration_model_name: String,
|
||||
) -> Self {
|
||||
pub fn new(orchestrator_url: String, orchestration_model_name: String) -> Self {
|
||||
// Empty agent orchestrations - will be provided via usage_preferences in requests
|
||||
let agent_orchestrations: HashMap<String, Vec<OrchestrationPreference>> = HashMap::new();
|
||||
|
||||
|
|
|
|||
|
|
@ -85,13 +85,19 @@ impl StateStorage for MemoryConversationalStorage {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use hermesllm::apis::openai_responses::{InputItem, InputMessage, MessageRole, InputContent, MessageContent};
|
||||
use hermesllm::apis::openai_responses::{
|
||||
InputContent, InputItem, InputMessage, MessageContent, MessageRole,
|
||||
};
|
||||
|
||||
fn create_test_state(response_id: &str, num_messages: usize) -> OpenAIConversationState {
|
||||
let mut input_items = Vec::new();
|
||||
for i in 0..num_messages {
|
||||
input_items.push(InputItem::Message(InputMessage {
|
||||
role: if i % 2 == 0 { MessageRole::User } else { MessageRole::Assistant },
|
||||
role: if i % 2 == 0 {
|
||||
MessageRole::User
|
||||
} else {
|
||||
MessageRole::Assistant
|
||||
},
|
||||
content: MessageContent::Items(vec![InputContent::InputText {
|
||||
text: format!("Message {}", i),
|
||||
}]),
|
||||
|
|
@ -252,7 +258,9 @@ mod tests {
|
|||
let merged = storage.merge(&prev_state, current_input);
|
||||
|
||||
// Verify order: prev messages first, then current
|
||||
let InputItem::Message(msg) = &merged[0] else { panic!("Expected Message") };
|
||||
let InputItem::Message(msg) = &merged[0] else {
|
||||
panic!("Expected Message")
|
||||
};
|
||||
match &msg.content {
|
||||
MessageContent::Items(items) => match &items[0] {
|
||||
InputContent::InputText { text } => assert_eq!(text, "Message 0"),
|
||||
|
|
@ -261,7 +269,9 @@ mod tests {
|
|||
_ => panic!("Expected MessageContent::Items"),
|
||||
}
|
||||
|
||||
let InputItem::Message(msg) = &merged[2] else { panic!("Expected Message") };
|
||||
let InputItem::Message(msg) = &merged[2] else {
|
||||
panic!("Expected Message")
|
||||
};
|
||||
match &msg.content {
|
||||
MessageContent::Items(items) => match &items[0] {
|
||||
InputContent::InputText { text } => assert_eq!(text, "Message 2"),
|
||||
|
|
@ -404,7 +414,8 @@ mod tests {
|
|||
let current_input = vec![InputItem::Message(InputMessage {
|
||||
role: MessageRole::User,
|
||||
content: MessageContent::Items(vec![InputContent::InputText {
|
||||
text: "Function result: {\"temperature\": 72, \"condition\": \"sunny\"}".to_string(),
|
||||
text: "Function result: {\"temperature\": 72, \"condition\": \"sunny\"}"
|
||||
.to_string(),
|
||||
}]),
|
||||
})];
|
||||
|
||||
|
|
@ -415,7 +426,9 @@ mod tests {
|
|||
assert_eq!(merged.len(), 3);
|
||||
|
||||
// Verify the order and content
|
||||
let InputItem::Message(msg1) = &merged[0] else { panic!("Expected Message") };
|
||||
let InputItem::Message(msg1) = &merged[0] else {
|
||||
panic!("Expected Message")
|
||||
};
|
||||
assert!(matches!(msg1.role, MessageRole::User));
|
||||
match &msg1.content {
|
||||
MessageContent::Items(items) => match &items[0] {
|
||||
|
|
@ -427,7 +440,9 @@ mod tests {
|
|||
_ => panic!("Expected MessageContent::Items"),
|
||||
}
|
||||
|
||||
let InputItem::Message(msg2) = &merged[1] else { panic!("Expected Message") };
|
||||
let InputItem::Message(msg2) = &merged[1] else {
|
||||
panic!("Expected Message")
|
||||
};
|
||||
assert!(matches!(msg2.role, MessageRole::Assistant));
|
||||
match &msg2.content {
|
||||
MessageContent::Items(items) => match &items[0] {
|
||||
|
|
@ -439,7 +454,9 @@ mod tests {
|
|||
_ => panic!("Expected MessageContent::Items"),
|
||||
}
|
||||
|
||||
let InputItem::Message(msg3) = &merged[2] else { panic!("Expected Message") };
|
||||
let InputItem::Message(msg3) = &merged[2] else {
|
||||
panic!("Expected Message")
|
||||
};
|
||||
assert!(matches!(msg3.role, MessageRole::User));
|
||||
match &msg3.content {
|
||||
MessageContent::Items(items) => match &items[0] {
|
||||
|
|
@ -508,11 +525,15 @@ mod tests {
|
|||
assert_eq!(merged.len(), 5);
|
||||
|
||||
// Verify first item is original user message
|
||||
let InputItem::Message(first) = &merged[0] else { panic!("Expected Message") };
|
||||
let InputItem::Message(first) = &merged[0] else {
|
||||
panic!("Expected Message")
|
||||
};
|
||||
assert!(matches!(first.role, MessageRole::User));
|
||||
|
||||
// Verify last two are function outputs
|
||||
let InputItem::Message(second_last) = &merged[3] else { panic!("Expected Message") };
|
||||
let InputItem::Message(second_last) = &merged[3] else {
|
||||
panic!("Expected Message")
|
||||
};
|
||||
assert!(matches!(second_last.role, MessageRole::User));
|
||||
match &second_last.content {
|
||||
MessageContent::Items(items) => match &items[0] {
|
||||
|
|
@ -522,7 +543,9 @@ mod tests {
|
|||
_ => panic!("Expected MessageContent::Items"),
|
||||
}
|
||||
|
||||
let InputItem::Message(last) = &merged[4] else { panic!("Expected Message") };
|
||||
let InputItem::Message(last) = &merged[4] else {
|
||||
panic!("Expected Message")
|
||||
};
|
||||
assert!(matches!(last.role, MessageRole::User));
|
||||
match &last.content {
|
||||
MessageContent::Items(items) => match &items[0] {
|
||||
|
|
@ -590,7 +613,9 @@ mod tests {
|
|||
assert_eq!(merged.len(), 5);
|
||||
|
||||
// Verify the entire conversation flow is preserved
|
||||
let InputItem::Message(first) = &merged[0] else { panic!("Expected Message") };
|
||||
let InputItem::Message(first) = &merged[0] else {
|
||||
panic!("Expected Message")
|
||||
};
|
||||
match &first.content {
|
||||
MessageContent::Items(items) => match &items[0] {
|
||||
InputContent::InputText { text } => assert!(text.contains("What's the weather")),
|
||||
|
|
@ -599,7 +624,9 @@ mod tests {
|
|||
_ => panic!("Expected MessageContent::Items"),
|
||||
}
|
||||
|
||||
let InputItem::Message(last) = &merged[4] else { panic!("Expected Message") };
|
||||
let InputItem::Message(last) = &merged[4] else {
|
||||
panic!("Expected Message")
|
||||
};
|
||||
match &last.content {
|
||||
MessageContent::Items(items) => match &items[0] {
|
||||
InputContent::InputText { text } => assert!(text.contains("umbrella")),
|
||||
|
|
|
|||
|
|
@ -1,14 +1,16 @@
|
|||
use async_trait::async_trait;
|
||||
use hermesllm::apis::openai_responses::{InputItem, InputMessage, InputContent, MessageContent, MessageRole, InputParam};
|
||||
use hermesllm::apis::openai_responses::{
|
||||
InputContent, InputItem, InputMessage, InputParam, MessageContent, MessageRole,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug};
|
||||
use tracing::debug;
|
||||
|
||||
pub mod memory;
|
||||
pub mod response_state_processor;
|
||||
pub mod postgresql;
|
||||
pub mod response_state_processor;
|
||||
|
||||
/// Represents the conversational state for a v1/responses request
|
||||
/// Contains the complete input/output history that can be restored
|
||||
|
|
@ -47,7 +49,9 @@ pub enum StateStorageError {
|
|||
impl fmt::Display for StateStorageError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
StateStorageError::NotFound(id) => write!(f, "Conversation state not found for response_id: {}", id),
|
||||
StateStorageError::NotFound(id) => {
|
||||
write!(f, "Conversation state not found for response_id: {}", id)
|
||||
}
|
||||
StateStorageError::StorageError(msg) => write!(f, "Storage error: {}", msg),
|
||||
StateStorageError::SerializationError(msg) => write!(f, "Serialization error: {}", msg),
|
||||
}
|
||||
|
|
@ -96,8 +100,6 @@ pub trait StateStorage: Send + Sync {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
/// Storage backend type enum
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum StorageBackend {
|
||||
|
|
@ -106,7 +108,7 @@ pub enum StorageBackend {
|
|||
}
|
||||
|
||||
impl StorageBackend {
|
||||
pub fn from_str(s: &str) -> Option<Self> {
|
||||
pub fn parse_backend(s: &str) -> Option<Self> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"memory" => Some(StorageBackend::Memory),
|
||||
"supabase" => Some(StorageBackend::Supabase),
|
||||
|
|
@ -139,7 +141,6 @@ pub async fn retrieve_and_combine_input(
|
|||
previous_response_id: &str,
|
||||
current_input: Vec<InputItem>,
|
||||
) -> Result<Vec<InputItem>, StateStorageError> {
|
||||
|
||||
// First get the previous state
|
||||
let prev_state = storage.get(previous_response_id).await?;
|
||||
let combined_input = storage.merge(&prev_state, current_input);
|
||||
|
|
|
|||
|
|
@ -149,13 +149,12 @@ impl StateStorage for PostgreSQLConversationStorage {
|
|||
let provider: String = row.get("provider");
|
||||
|
||||
// Deserialize input_items from JSONB
|
||||
let input_items =
|
||||
serde_json::from_value(input_items_json).map_err(|e| {
|
||||
StateStorageError::StorageError(format!(
|
||||
"Failed to deserialize input_items: {}",
|
||||
e
|
||||
))
|
||||
})?;
|
||||
let input_items = serde_json::from_value(input_items_json).map_err(|e| {
|
||||
StateStorageError::StorageError(format!(
|
||||
"Failed to deserialize input_items: {}",
|
||||
e
|
||||
))
|
||||
})?;
|
||||
|
||||
Ok(OpenAIConversationState {
|
||||
response_id,
|
||||
|
|
@ -230,7 +229,9 @@ Run that SQL file against your database before using this storage backend.
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use hermesllm::apis::openai_responses::{InputContent, InputItem, InputMessage, MessageContent, MessageRole};
|
||||
use hermesllm::apis::openai_responses::{
|
||||
InputContent, InputItem, InputMessage, MessageContent, MessageRole,
|
||||
};
|
||||
|
||||
fn create_test_state(response_id: &str) -> OpenAIConversationState {
|
||||
OpenAIConversationState {
|
||||
|
|
@ -320,7 +321,10 @@ mod tests {
|
|||
|
||||
let result = storage.get("nonexistent_id").await;
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(result.unwrap_err(), StateStorageError::NotFound(_)));
|
||||
assert!(matches!(
|
||||
result.unwrap_err(),
|
||||
StateStorageError::NotFound(_)
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -372,7 +376,10 @@ mod tests {
|
|||
|
||||
let result = storage.delete("nonexistent_id").await;
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(result.unwrap_err(), StateStorageError::NotFound(_)));
|
||||
assert!(matches!(
|
||||
result.unwrap_err(),
|
||||
StateStorageError::NotFound(_)
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -423,9 +430,13 @@ mod tests {
|
|||
|
||||
println!("✅ Data written to Supabase!");
|
||||
println!("Check your Supabase dashboard:");
|
||||
println!(" SELECT * FROM conversation_states WHERE response_id = 'manual_test_verification';");
|
||||
println!(
|
||||
" SELECT * FROM conversation_states WHERE response_id = 'manual_test_verification';"
|
||||
);
|
||||
println!("\nTo cleanup, run:");
|
||||
println!(" DELETE FROM conversation_states WHERE response_id = 'manual_test_verification';");
|
||||
println!(
|
||||
" DELETE FROM conversation_states WHERE response_id = 'manual_test_verification';"
|
||||
);
|
||||
|
||||
// DON'T clean up - leave it for manual verification
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,13 +1,11 @@
|
|||
use bytes::Bytes;
|
||||
use flate2::read::GzDecoder;
|
||||
use hermesllm::apis::openai_responses::{
|
||||
InputItem, OutputItem, ResponsesAPIStreamEvent,
|
||||
};
|
||||
use hermesllm::apis::openai_responses::{InputItem, OutputItem, ResponsesAPIStreamEvent};
|
||||
use hermesllm::apis::streaming_shapes::sse::SseStreamIter;
|
||||
use hermesllm::transforms::response::output_to_input::outputs_to_inputs;
|
||||
use std::io::Read;
|
||||
use std::sync::Arc;
|
||||
use tracing::{info, debug, warn};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::handlers::utils::StreamProcessor;
|
||||
use crate::state::{OpenAIConversationState, StateStorage};
|
||||
|
|
@ -53,6 +51,7 @@ pub struct ResponsesStateProcessor<P: StreamProcessor> {
|
|||
}
|
||||
|
||||
impl<P: StreamProcessor> ResponsesStateProcessor<P> {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
inner: P,
|
||||
storage: Arc<dyn StateStorage>,
|
||||
|
|
@ -139,20 +138,19 @@ impl<P: StreamProcessor> ResponsesStateProcessor<P> {
|
|||
for event in sse_iter {
|
||||
// Only process data lines (skip event-only lines)
|
||||
if let Some(data_str) = &event.data {
|
||||
// Try to parse as ResponsesAPIStreamEvent
|
||||
if let Ok(stream_event) = serde_json::from_str::<ResponsesAPIStreamEvent>(data_str) {
|
||||
// Check if this is a ResponseCompleted event
|
||||
if let ResponsesAPIStreamEvent::ResponseCompleted { response, .. } = stream_event {
|
||||
info!(
|
||||
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Captured streaming response.completed: response_id={}, output_items={}",
|
||||
self.request_id,
|
||||
response.id,
|
||||
response.output.len()
|
||||
);
|
||||
self.response_id = Some(response.id.clone());
|
||||
self.output_items = Some(response.output.clone());
|
||||
return; // Found what we need, exit early
|
||||
}
|
||||
// Try to parse as ResponsesAPIStreamEvent and check if it's a ResponseCompleted event
|
||||
if let Ok(ResponsesAPIStreamEvent::ResponseCompleted { response, .. }) =
|
||||
serde_json::from_str::<ResponsesAPIStreamEvent>(data_str)
|
||||
{
|
||||
info!(
|
||||
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Captured streaming response.completed: response_id={}, output_items={}",
|
||||
self.request_id,
|
||||
response.id,
|
||||
response.output.len()
|
||||
);
|
||||
self.response_id = Some(response.id.clone());
|
||||
self.output_items = Some(response.output.clone());
|
||||
return; // Found what we need, exit early
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -172,7 +170,9 @@ impl<P: StreamProcessor> ResponsesStateProcessor<P> {
|
|||
let decompressed = self.decompress_buffer();
|
||||
|
||||
// Parse complete JSON response
|
||||
match serde_json::from_slice::<hermesllm::apis::openai_responses::ResponsesAPIResponse>(&decompressed) {
|
||||
match serde_json::from_slice::<hermesllm::apis::openai_responses::ResponsesAPIResponse>(
|
||||
&decompressed,
|
||||
) {
|
||||
Ok(response) => {
|
||||
info!(
|
||||
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Captured non-streaming response: response_id={}, output_items={}",
|
||||
|
|
|
|||
|
|
@ -2,11 +2,9 @@
|
|||
///
|
||||
/// This module defines standard attribute keys following OTEL semantic conventions.
|
||||
/// See: https://opentelemetry.io/docs/specs/semconv/
|
||||
|
||||
// =============================================================================
|
||||
// Span Attributes - HTTP
|
||||
// =============================================================================
|
||||
|
||||
/// Semantic conventions for HTTP-related span attributes
|
||||
pub mod http {
|
||||
/// HTTP request method
|
||||
|
|
|
|||
|
|
@ -1,3 +1,3 @@
|
|||
mod constants;
|
||||
|
||||
pub use constants::{OperationNameBuilder, operation_component, http, llm, error, routing};
|
||||
pub use constants::{error, http, llm, operation_component, routing, OperationNameBuilder};
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue