diff --git a/crates/brightstaff/src/handlers/agent_chat_completions.rs b/crates/brightstaff/src/handlers/agent_chat_completions.rs index fb7db45a..2e8e0961 100644 --- a/crates/brightstaff/src/handlers/agent_chat_completions.rs +++ b/crates/brightstaff/src/handlers/agent_chat_completions.rs @@ -1,10 +1,13 @@ use std::sync::Arc; use bytes::Bytes; -use hermesllm::apis::openai::ChatCompletionsRequest; +use hermesllm::apis::OpenAIMessage; +use hermesllm::clients::SupportedAPIsFromClient; +use hermesllm::ProviderRequestType; use http_body_util::combinators::BoxBody; use http_body_util::BodyExt; use hyper::{Request, Response}; +use serde::ser::Error as SerError; use tracing::{debug, info, warn}; use super::agent_selector::{AgentSelectionError, AgentSelector}; @@ -35,7 +38,15 @@ pub async fn agent_chat( listeners: Arc>>, trace_collector: Arc, ) -> Result>, hyper::Error> { - match handle_agent_chat(request, router_service, agents_list, listeners, trace_collector).await { + match handle_agent_chat( + request, + router_service, + agents_list, + listeners, + trace_collector, + ) + .await + { Ok(response) => Ok(response), Err(err) => { // Check if this is a client error from the pipeline that should be cascaded @@ -134,6 +145,13 @@ async fn handle_agent_chat( info!("Handling request for listener: {}", listener.name); // Parse request body + let request_path = request + .uri() + .path() + .to_string() + .strip_prefix("/agents") + .unwrap() + .to_string(); let request_headers = request.headers().clone(); let chat_request_bytes = request.collect().await?.to_bytes(); @@ -142,15 +160,36 @@ async fn handle_agent_chat( String::from_utf8_lossy(&chat_request_bytes) ); - let chat_completions_request: ChatCompletionsRequest = - serde_json::from_slice(&chat_request_bytes).map_err(|err| { - warn!( - "Failed to parse request body as ChatCompletionsRequest: {}", - err - ); - AgentFilterChainError::RequestParsing(err) + // Determine the API type from the endpoint + let api_type = + SupportedAPIsFromClient::from_endpoint(request_path.as_str()).ok_or_else(|| { + let err_msg = format!("Unsupported endpoint: {}", request_path); + warn!("{}", err_msg); + AgentFilterChainError::RequestParsing(serde_json::Error::custom(err_msg)) })?; + let client_request = match ProviderRequestType::try_from((&chat_request_bytes[..], &api_type)) { + Ok(request) => request, + Err(err) => { + warn!("Failed to parse request as ProviderRequestType: {}", err); + let err_msg = format!("Failed to parse request: {}", err); + return Err(AgentFilterChainError::RequestParsing( + serde_json::Error::custom(err_msg), + )); + } + }; + + let message: Vec = client_request.get_message_history(); + + // let chat_completions_request: ChatCompletionsRequest = + // serde_json::from_slice(&chat_request_bytes).map_err(|err| { + // warn!( + // "Failed to parse request body as ChatCompletionsRequest: {}", + // err + // ); + // AgentFilterChainError::RequestParsing(err) + // })?; + // Extract trace parent for routing let trace_parent = request_headers .iter() @@ -166,11 +205,7 @@ async fn handle_agent_chat( // Select appropriate agent using arch router llm model let selected_agent = agent_selector - .select_agent( - &chat_completions_request.messages, - &listener, - trace_parent, - ) + .select_agent(&message, &listener, trace_parent) .await?; debug!("Processing agent pipeline: {}", selected_agent.id); @@ -178,7 +213,7 @@ async fn handle_agent_chat( // Process the filter chain let chat_history = pipeline_processor .process_filter_chain( - &chat_completions_request.messages, + &message, &selected_agent, &agent_map, &request_headers, @@ -196,7 +231,7 @@ async fn handle_agent_chat( let llm_response = pipeline_processor .invoke_terminal_agent( &chat_history, - &chat_completions_request, + client_request, terminal_agent, &request_headers, ) diff --git a/crates/brightstaff/src/handlers/pipeline_processor.rs b/crates/brightstaff/src/handlers/pipeline_processor.rs index 6c763679..8c11838c 100644 --- a/crates/brightstaff/src/handlers/pipeline_processor.rs +++ b/crates/brightstaff/src/handlers/pipeline_processor.rs @@ -3,7 +3,8 @@ use std::collections::HashMap; use common::configuration::{Agent, AgentFilterChain}; use common::consts::{ARCH_UPSTREAM_HOST_HEADER, ENVOY_RETRY_HEADER}; use common::traces::{SpanBuilder, SpanKind}; -use hermesllm::apis::openai::{ChatCompletionsRequest, Message}; +use hermesllm::{ProviderRequest, ProviderRequestType}; +use hermesllm::apis::openai::{Message}; use hyper::header::HeaderMap; use opentelemetry::trace::TraceContextExt; use tracing::{debug, info, warn}; @@ -468,14 +469,15 @@ impl PipelineProcessor { pub async fn invoke_terminal_agent( &self, messages: &[Message], - original_request: &ChatCompletionsRequest, + mut original_request: ProviderRequestType, terminal_agent: &Agent, request_headers: &HeaderMap, ) -> Result { - let mut request = original_request.clone(); - request.messages = messages.to_vec(); + // let mut request = original_request.clone(); + original_request.set_messages(messages); - let request_body = serde_json::to_string(&request)?; + let request_body = ProviderRequestType::to_bytes(&original_request).unwrap(); + // let request_body = serde_json::to_string(&request)?; debug!("Sending request to terminal agent {}", terminal_agent.id); let mut agent_headers = request_headers.clone(); diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index f9b7b8c5..f14e3469 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -1,7 +1,7 @@ use brightstaff::handlers::agent_chat_completions::agent_chat; +use brightstaff::handlers::function_calling::function_calling_chat_handler; use brightstaff::handlers::llm::llm_chat; use brightstaff::handlers::models::list_models; -use brightstaff::handlers::function_calling::{function_calling_chat_handler}; use brightstaff::router::llm_router::RouterService; use brightstaff::utils::tracing::init_tracer; use bytes::Bytes; @@ -105,20 +105,23 @@ async fn main() -> Result<(), Box> { info!("Tracing configuration found in arch_config.yaml"); Some(true) } else { - info!("No tracing configuration in arch_config.yaml, will check OTEL_TRACING_ENABLED env var"); + info!( + "No tracing configuration in arch_config.yaml, will check OTEL_TRACING_ENABLED env var" + ); None }; let trace_collector = Arc::new(TraceCollector::new(tracing_enabled)); let _flusher_handle = trace_collector.clone().start_background_flusher(); - loop { let (stream, _) = listener.accept().await?; let peer_addr = stream.peer_addr()?; let io = TokioIo::new(stream); let router_service: Arc = Arc::clone(&router_service); - let model_aliases: Arc>> = Arc::clone(&model_aliases); + let model_aliases: Arc< + Option>, + > = Arc::clone(&model_aliases); let llm_provider_url = llm_provider_url.clone(); let llm_providers = llm_providers.clone(); @@ -136,18 +139,18 @@ async fn main() -> Result<(), Box> { let trace_collector = trace_collector.clone(); async move { - match (req.method(), req.uri().path()) { - (&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH) => { - let fully_qualified_url = - format!("{}{}", llm_provider_url, req.uri().path()); - llm_chat(req, router_service, fully_qualified_url, model_aliases, llm_providers, trace_collector) - .with_context(parent_cx) - .await - } - (&Method::POST, "/agents/v1/chat/completions") => { - let fully_qualified_url = - format!("{}{}", llm_provider_url, req.uri().path()); - agent_chat( + let path = req.uri().path(); + + // Check if path starts with /agents + if path.starts_with("/agents") { + // Check if it matches one of the agent API paths + let stripped_path = path.strip_prefix("/agents").unwrap(); + if matches!( + stripped_path, + CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH + ) { + let fully_qualified_url = format!("{}{}", llm_provider_url, stripped_path); + return agent_chat( req, router_service, fully_qualified_url, @@ -156,6 +159,26 @@ async fn main() -> Result<(), Box> { trace_collector, ) .with_context(parent_cx) + .await; + } + } + + match (req.method(), path) { + ( + &Method::POST, + CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH, + ) => { + let fully_qualified_url = + format!("{}{}", llm_provider_url, req.uri().path()); + llm_chat( + req, + router_service, + fully_qualified_url, + model_aliases, + llm_providers, + trace_collector, + ) + .with_context(parent_cx) .await } diff --git a/crates/hermesllm/src/providers/request.rs b/crates/hermesllm/src/providers/request.rs index eb8f0788..72246f39 100644 --- a/crates/hermesllm/src/providers/request.rs +++ b/crates/hermesllm/src/providers/request.rs @@ -49,6 +49,353 @@ pub trait ProviderRequest: Send + Sync { fn get_temperature(&self) -> Option; } +impl ProviderRequestType { + /// Get message history as OpenAI Message format + /// This is useful for processing chat history across different provider formats + pub fn get_message_history(&self) -> Vec { + use crate::apis::openai::{Message, MessageContent, Role}; + + match self { + Self::ChatCompletionsRequest(r) => r.messages.clone(), + Self::MessagesRequest(r) => { + // Convert Anthropic messages to OpenAI format + let mut openai_messages = Vec::new(); + + // Add system prompt as system message if present + if let Some(system) = &r.system { + openai_messages.push(system.clone().into()); + } + + // Convert each Anthropic message to OpenAI format + for msg in &r.messages { + if let Ok(converted_msgs) = TryInto::>::try_into(msg.clone()) { + openai_messages.extend(converted_msgs); + } + } + + openai_messages + } + Self::BedrockConverse(r) => { + // Convert Bedrock messages to OpenAI format + let mut openai_messages = Vec::new(); + + // Add system messages if present + if let Some(system) = &r.system { + for sys_block in system { + match sys_block { + crate::apis::amazon_bedrock::SystemContentBlock::Text { text } => { + openai_messages.push(Message { + role: Role::System, + content: MessageContent::Text(text.clone()), + name: None, + tool_calls: None, + tool_call_id: None, + }); + } + _ => {} // Skip other system content types + } + } + } + + // Convert conversation messages + if let Some(messages) = &r.messages { + for msg in messages { + let role = match msg.role { + crate::apis::amazon_bedrock::ConversationRole::User => Role::User, + crate::apis::amazon_bedrock::ConversationRole::Assistant => Role::Assistant, + }; + + // Extract text from content blocks + let content = msg.content.iter() + .filter_map(|block| { + if let crate::apis::amazon_bedrock::ContentBlock::Text { text } = block { + Some(text.clone()) + } else { + None + } + }) + .collect::>() + .join("\n"); + + openai_messages.push(Message { + role, + content: MessageContent::Text(content), + name: None, + tool_calls: None, + tool_call_id: None, + }); + } + } + + openai_messages + } + Self::BedrockConverseStream(r) => { + // Same as BedrockConverse + let mut openai_messages = Vec::new(); + + if let Some(system) = &r.system { + for sys_block in system { + match sys_block { + crate::apis::amazon_bedrock::SystemContentBlock::Text { text } => { + openai_messages.push(Message { + role: Role::System, + content: MessageContent::Text(text.clone()), + name: None, + tool_calls: None, + tool_call_id: None, + }); + } + _ => {} // Skip other system content types + } + } + } + + if let Some(messages) = &r.messages { + for msg in messages { + let role = match msg.role { + crate::apis::amazon_bedrock::ConversationRole::User => Role::User, + crate::apis::amazon_bedrock::ConversationRole::Assistant => Role::Assistant, + }; + + let content = msg.content.iter() + .filter_map(|block| { + if let crate::apis::amazon_bedrock::ContentBlock::Text { text } = block { + Some(text.clone()) + } else { + None + } + }) + .collect::>() + .join("\n"); + + openai_messages.push(Message { + role, + content: MessageContent::Text(content), + name: None, + tool_calls: None, + tool_call_id: None, + }); + } + } + + openai_messages + } + Self::ResponsesAPIRequest(r) => { + // Convert ResponsesAPIRequest input to a user message + let mut openai_messages = Vec::new(); + + // Add instructions as system message if present + if let Some(instructions) = &r.instructions { + openai_messages.push(Message { + role: Role::System, + content: MessageContent::Text(instructions.clone()), + name: None, + tool_calls: None, + tool_call_id: None, + }); + } + + // Convert input to messages + use crate::apis::openai_responses::{InputParam, InputItem}; + match &r.input { + InputParam::Text(text) => { + openai_messages.push(Message { + role: Role::User, + content: MessageContent::Text(text.clone()), + name: None, + tool_calls: None, + tool_call_id: None, + }); + } + InputParam::Items(items) => { + for item in items { + match item { + InputItem::Message(msg) => { + // Convert message role + let role = match msg.role { + crate::apis::openai_responses::MessageRole::User => Role::User, + crate::apis::openai_responses::MessageRole::Assistant => Role::Assistant, + crate::apis::openai_responses::MessageRole::System => Role::System, + crate::apis::openai_responses::MessageRole::Developer => Role::System, // Map developer to system + }; + + // Extract text from message content + let content = msg.content.iter() + .filter_map(|c| { + if let crate::apis::openai_responses::InputContent::InputText { text } = c { + Some(text.clone()) + } else { + None + } + }) + .collect::>() + .join("\n"); + + openai_messages.push(Message { + role, + content: MessageContent::Text(content), + name: None, + tool_calls: None, + tool_call_id: None, + }); + } + } + } + } + } + + openai_messages + } + } + } + + /// Set message history from OpenAI Message format + /// This converts OpenAI messages to the appropriate format for each provider type + pub fn set_messages(&mut self, messages: &[crate::apis::openai::Message]) { + match self { + Self::ChatCompletionsRequest(r) => { + r.messages = messages.to_vec(); + } + Self::MessagesRequest(r) => { + // Convert OpenAI messages to Anthropic format + // Separate system messages from regular messages + let mut system_messages = Vec::new(); + let mut regular_messages = Vec::new(); + + for msg in messages { + if msg.role == crate::apis::openai::Role::System { + system_messages.push(msg.clone()); + } else { + regular_messages.push(msg.clone()); + } + } + + // Set system prompt if there are system messages + if !system_messages.is_empty() { + // Combine all system messages into one + let system_text = system_messages.iter() + .filter_map(|msg| { + if let crate::apis::openai::MessageContent::Text(text) = &msg.content { + Some(text.as_str()) + } else { + None + } + }) + .collect::>() + .join("\n"); + + r.system = Some(crate::apis::anthropic::MessagesSystemPrompt::Single(system_text)); + } + + // Convert regular messages + r.messages = regular_messages.iter() + .filter_map(|msg| { + msg.clone().try_into().ok() + }) + .collect(); + } + Self::BedrockConverse(r) | Self::BedrockConverseStream(r) => { + // Convert OpenAI messages to Bedrock format + use crate::apis::amazon_bedrock::{ContentBlock, ConversationRole, SystemContentBlock}; + + let mut system_blocks = Vec::new(); + let mut bedrock_messages = Vec::new(); + + for msg in messages { + match msg.role { + crate::apis::openai::Role::System => { + if let crate::apis::openai::MessageContent::Text(text) = &msg.content { + system_blocks.push(SystemContentBlock::Text { text: text.clone() }); + } + } + crate::apis::openai::Role::User | crate::apis::openai::Role::Assistant => { + let role = match msg.role { + crate::apis::openai::Role::User => ConversationRole::User, + crate::apis::openai::Role::Assistant => ConversationRole::Assistant, + _ => continue, + }; + + let content = if let crate::apis::openai::MessageContent::Text(text) = &msg.content { + vec![ContentBlock::Text { text: text.clone() }] + } else { + vec![] + }; + + bedrock_messages.push(crate::apis::amazon_bedrock::Message { + role, + content, + }); + } + _ => {} + } + } + + if !system_blocks.is_empty() { + r.system = Some(system_blocks); + } + r.messages = Some(bedrock_messages); + } + Self::ResponsesAPIRequest(r) => { + // For ResponsesAPI, we need to convert messages back to input format + // Extract system messages as instructions + let system_text = messages.iter() + .filter(|msg| msg.role == crate::apis::openai::Role::System) + .filter_map(|msg| { + if let crate::apis::openai::MessageContent::Text(text) = &msg.content { + Some(text.as_str()) + } else { + None + } + }) + .collect::>() + .join("\n"); + + if !system_text.is_empty() { + r.instructions = Some(system_text); + } + + // Convert user/assistant messages to InputParam + // For simplicity, we'll use the last user message as the input + // or combine all non-system messages + let input_messages: Vec<_> = messages.iter() + .filter(|msg| msg.role != crate::apis::openai::Role::System) + .collect(); + + if !input_messages.is_empty() { + // If there's only one message, use Text format + if input_messages.len() == 1 { + if let crate::apis::openai::MessageContent::Text(text) = &input_messages[0].content { + r.input = crate::apis::openai_responses::InputParam::Text(text.clone()); + } + } else { + // Multiple messages - combine them as text for now + // A more sophisticated approach would use InputParam::Items + let combined_text = input_messages.iter() + .filter_map(|msg| { + if let crate::apis::openai::MessageContent::Text(text) = &msg.content { + Some(format!("{}: {}", + match msg.role { + crate::apis::openai::Role::User => "User", + crate::apis::openai::Role::Assistant => "Assistant", + _ => "Unknown", + }, + text + )) + } else { + None + } + }) + .collect::>() + .join("\n"); + + r.input = crate::apis::openai_responses::InputParam::Text(combined_text); + } + } + } + } + } +} + impl ProviderRequest for ProviderRequestType { fn model(&self) -> &str { match self { @@ -934,4 +1281,131 @@ mod tests { .message .contains("OpenAI ChatCompletions, Anthropic Messages, and OpenAI Responses")); } + + #[test] + fn test_get_message_history_chat_completions() { + use crate::apis::openai::{Message, MessageContent, Role}; + + let chat_req = ChatCompletionsRequest { + model: "gpt-4".to_string(), + messages: vec![ + Message { + role: Role::System, + content: MessageContent::Text("You are helpful".to_string()), + name: None, + tool_calls: None, + tool_call_id: None, + }, + Message { + role: Role::User, + content: MessageContent::Text("Hello!".to_string()), + name: None, + tool_calls: None, + tool_call_id: None, + }, + ], + ..Default::default() + }; + + let provider_req = ProviderRequestType::ChatCompletionsRequest(chat_req); + let messages = provider_req.get_message_history(); + + assert_eq!(messages.len(), 2); + assert_eq!(messages[0].role, Role::System); + assert_eq!(messages[1].role, Role::User); + } + + #[test] + fn test_get_message_history_anthropic_messages() { + use crate::apis::anthropic::{ + MessagesMessage, MessagesMessageContent, MessagesRequest, MessagesRole, + MessagesSystemPrompt, + }; + + let anthropic_req = MessagesRequest { + model: "claude-3-sonnet".to_string(), + messages: vec![MessagesMessage { + role: MessagesRole::User, + content: MessagesMessageContent::Single("Hello!".to_string()), + }], + system: Some(MessagesSystemPrompt::Single( + "You are helpful".to_string(), + )), + max_tokens: 100, + container: None, + mcp_servers: None, + metadata: None, + service_tier: None, + thinking: None, + temperature: None, + top_p: None, + top_k: None, + stream: None, + stop_sequences: None, + tools: None, + tool_choice: None, + }; + + let provider_req = ProviderRequestType::MessagesRequest(anthropic_req); + let messages = provider_req.get_message_history(); + + // Should have system message + user message + assert_eq!(messages.len(), 2); + assert_eq!( + messages[0].role, + crate::apis::openai::Role::System + ); + assert_eq!( + messages[1].role, + crate::apis::openai::Role::User + ); + } + + #[test] + fn test_get_message_history_responses_api() { + use crate::apis::openai_responses::{InputParam, ResponsesAPIRequest}; + + let responses_req = ResponsesAPIRequest { + model: "gpt-4o".to_string(), + input: InputParam::Text("Hello, world!".to_string()), + instructions: Some("Be helpful".to_string()), + temperature: None, + max_output_tokens: None, + stream: None, + metadata: None, + tools: None, + tool_choice: None, + parallel_tool_calls: None, + modalities: None, + user: None, + store: None, + reasoning_effort: None, + include: None, + audio: None, + text: None, + service_tier: None, + top_p: None, + top_logprobs: None, + stream_options: None, + truncation: None, + conversation: None, + previous_response_id: None, + max_tool_calls: None, + background: None, + }; + + let provider_req = ProviderRequestType::ResponsesAPIRequest(responses_req); + let messages = provider_req.get_message_history(); + + // Should have system message (instructions) + user message (input) + assert_eq!(messages.len(), 2); + assert_eq!( + messages[0].role, + crate::apis::openai::Role::System + ); + assert_eq!( + messages[1].role, + crate::apis::openai::Role::User + ); + } }