From ca95ffb63d5e70713452571d4fec9335436d7a70 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Thu, 25 Dec 2025 21:08:37 -0800 Subject: [PATCH] cargo clippy (#660) --- .pre-commit-config.yaml | 7 +- .../src/handlers/agent_chat_completions.rs | 44 +- .../src/handlers/agent_selector.rs | 2 +- .../src/handlers/function_calling.rs | 556 +++++++++++------- .../src/handlers/integration_tests.rs | 7 +- crates/brightstaff/src/handlers/jsonrpc.rs | 44 +- crates/brightstaff/src/handlers/llm.rs | 91 ++- crates/brightstaff/src/handlers/mod.rs | 8 +- .../src/handlers/pipeline_processor.rs | 17 +- .../src/handlers/response_handler.rs | 8 +- .../brightstaff/src/handlers/router_chat.rs | 37 +- crates/brightstaff/src/handlers/utils.rs | 19 +- crates/brightstaff/src/main.rs | 81 +-- .../src/router/orchestrator_model_v1.rs | 127 ++-- .../src/router/plano_orchestrator.rs | 5 +- crates/brightstaff/src/state/memory.rs | 53 +- crates/brightstaff/src/state/mod.rs | 17 +- crates/brightstaff/src/state/postgresql.rs | 35 +- .../src/state/response_state_processor.rs | 38 +- crates/brightstaff/src/tracing/constants.rs | 2 - crates/brightstaff/src/tracing/mod.rs | 2 +- crates/common/src/configuration.rs | 63 +- crates/common/src/consts.rs | 2 +- crates/common/src/lib.rs | 2 +- crates/common/src/routing.rs | 3 +- crates/common/src/traces/collector.rs | 13 +- crates/common/src/traces/constants.rs | 1 - crates/common/src/traces/mod.rs | 14 +- .../src/traces/resource_span_builder.rs | 22 +- crates/common/src/traces/span_builder.rs | 9 +- .../src/traces/tests/mock_otel_collector.rs | 9 +- .../traces/tests/trace_integration_test.rs | 72 ++- crates/common/src/tracing.rs | 9 +- crates/hermesllm/src/apis/amazon_bedrock.rs | 79 +-- crates/hermesllm/src/apis/anthropic.rs | 24 +- crates/hermesllm/src/apis/openai.rs | 49 +- crates/hermesllm/src/apis/openai_responses.rs | 97 ++- .../amazon_bedrock_binary_frame.rs | 5 +- .../anthropic_streaming_buffer.rs | 211 +++++-- .../chat_completions_streaming_buffer.rs | 8 +- .../src/apis/streaming_shapes/mod.rs | 4 +- .../passthrough_streaming_buffer.rs | 18 +- .../responses_api_streaming_buffer.rs | 273 +++++++-- .../src/apis/streaming_shapes/sse.rs | 47 +- .../streaming_shapes/sse_chunk_processor.rs | 93 ++- crates/hermesllm/src/clients/endpoints.rs | 42 +- crates/hermesllm/src/lib.rs | 18 +- crates/hermesllm/src/providers/id.rs | 7 +- crates/hermesllm/src/providers/request.rs | 39 +- crates/hermesllm/src/providers/response.rs | 80 +-- .../src/providers/streaming_response.rs | 148 +++-- crates/hermesllm/src/transforms/lib.rs | 6 +- .../src/transforms/request/from_anthropic.rs | 36 +- .../src/transforms/request/from_openai.rs | 120 ++-- .../transforms/response/output_to_input.rs | 28 +- .../src/transforms/response/to_anthropic.rs | 4 +- .../src/transforms/response/to_openai.rs | 42 +- .../to_anthropic_streaming.rs | 35 +- .../response_streaming/to_openai_streaming.rs | 66 ++- crates/llm_gateway/src/stream_context.rs | 8 +- crates/prompt_gateway/src/http_context.rs | 10 +- crates/prompt_gateway/src/stream_context.rs | 35 +- 62 files changed, 1864 insertions(+), 1187 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5dd758ce..42b43943 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,19 +13,22 @@ repos: name: cargo-fmt language: system types: [file, rust] - entry: bash -c "cd crates/llm_gateway && cargo fmt" + entry: bash -c "cd crates && cargo fmt --all -- --check" + pass_filenames: false - id: cargo-clippy name: cargo-clippy language: system types: [file, rust] - entry: bash -c "cd crates/llm_gateway && cargo clippy --all" + entry: bash -c "cd crates && cargo clippy --locked --offline --all-targets --all-features -- -D warnings || cargo clippy --locked --all-targets --all-features -- -D warnings" + pass_filenames: false - id: cargo-test name: cargo-test language: system types: [file, rust] entry: bash -c "cd crates && cargo test --lib" + pass_filenames: false - repo: https://github.com/psf/black rev: 23.1.0 diff --git a/crates/brightstaff/src/handlers/agent_chat_completions.rs b/crates/brightstaff/src/handlers/agent_chat_completions.rs index 0c1232a2..df358190 100644 --- a/crates/brightstaff/src/handlers/agent_chat_completions.rs +++ b/crates/brightstaff/src/handlers/agent_chat_completions.rs @@ -3,7 +3,7 @@ use std::time::{Instant, SystemTime}; use bytes::Bytes; use common::consts::TRACE_PARENT_HEADER; -use common::traces::{SpanBuilder, SpanKind, parse_traceparent, generate_random_span_id}; +use common::traces::{generate_random_span_id, parse_traceparent, SpanBuilder, SpanKind}; use hermesllm::apis::OpenAIMessage; use hermesllm::clients::SupportedAPIsFromClient; use hermesllm::providers::request::ProviderRequest; @@ -18,7 +18,7 @@ use super::agent_selector::{AgentSelectionError, AgentSelector}; use super::pipeline_processor::{PipelineError, PipelineProcessor}; use super::response_handler::ResponseHandler; use crate::router::plano_orchestrator::OrchestratorService; -use crate::tracing::{OperationNameBuilder, operation_component, http}; +use crate::tracing::{http, operation_component, OperationNameBuilder}; /// Main errors for agent chat completions #[derive(Debug, thiserror::Error)] @@ -61,7 +61,6 @@ pub async fn agent_chat( body, }) = &err { - warn!( "Client error from agent '{}' (HTTP {}): {}", agent, status, body @@ -77,8 +76,8 @@ pub async fn agent_chat( let json_string = error_json.to_string(); let mut response = Response::new(ResponseHandler::create_full_body(json_string)); - *response.status_mut() = hyper::StatusCode::from_u16(*status) - .unwrap_or(hyper::StatusCode::BAD_REQUEST); + *response.status_mut() = + hyper::StatusCode::from_u16(*status).unwrap_or(hyper::StatusCode::BAD_REQUEST); response.headers_mut().insert( hyper::header::CONTENT_TYPE, "application/json".parse().unwrap(), @@ -234,8 +233,18 @@ async fn handle_agent_chat( .with_attribute(http::TARGET, "/agents/select") .with_attribute("selection.listener", listener.name.clone()) .with_attribute("selection.agent_count", selected_agents.len().to_string()) - .with_attribute("selection.agents", selected_agents.iter().map(|a| a.id.as_str()).collect::>().join(",")) - .with_attribute("duration_ms", format!("{:.2}", selection_elapsed.as_secs_f64() * 1000.0)); + .with_attribute( + "selection.agents", + selected_agents + .iter() + .map(|a| a.id.as_str()) + .collect::>() + .join(","), + ) + .with_attribute( + "duration_ms", + format!("{:.2}", selection_elapsed.as_secs_f64() * 1000.0), + ); if !trace_id.is_empty() { selection_span_builder = selection_span_builder.with_trace_id(trace_id.clone()); @@ -318,8 +327,14 @@ async fn handle_agent_chat( .with_attribute(http::METHOD, "POST") .with_attribute(http::TARGET, full_path) .with_attribute("agent.name", agent_name.clone()) - .with_attribute("agent.sequence", format!("{}/{}", agent_index + 1, agent_count)) - .with_attribute("duration_ms", format!("{:.2}", agent_elapsed.as_secs_f64() * 1000.0)); + .with_attribute( + "agent.sequence", + format!("{}/{}", agent_index + 1, agent_count), + ) + .with_attribute( + "duration_ms", + format!("{:.2}", agent_elapsed.as_secs_f64() * 1000.0), + ); if !trace_id.is_empty() { span_builder = span_builder.with_trace_id(trace_id.clone()); @@ -333,7 +348,10 @@ async fn handle_agent_chat( // If this is the last agent, return the streaming response if is_last_agent { - info!("Completed agent chain, returning response from last agent: {}", agent_name); + info!( + "Completed agent chain, returning response from last agent: {}", + agent_name + ); return response_handler .create_streaming_response(llm_response) .await @@ -341,7 +359,10 @@ async fn handle_agent_chat( } // For intermediate agents, collect the full response and pass to next agent - debug!("Collecting response from intermediate agent: {}", agent_name); + debug!( + "Collecting response from intermediate agent: {}", + agent_name + ); let response_text = response_handler.collect_full_response(llm_response).await?; info!( @@ -364,7 +385,6 @@ async fn handle_agent_chat( }); current_messages.push(last_message); - } // This should never be reached since we return in the last agent iteration diff --git a/crates/brightstaff/src/handlers/agent_selector.rs b/crates/brightstaff/src/handlers/agent_selector.rs index e26f6391..78fbc654 100644 --- a/crates/brightstaff/src/handlers/agent_selector.rs +++ b/crates/brightstaff/src/handlers/agent_selector.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use std::sync::Arc; use common::configuration::{ - Agent, AgentFilterChain, Listener, AgentUsagePreference, OrchestrationPreference, + Agent, AgentFilterChain, AgentUsagePreference, Listener, OrchestrationPreference, }; use hermesllm::apis::openai::Message; use tracing::{debug, warn}; diff --git a/crates/brightstaff/src/handlers/function_calling.rs b/crates/brightstaff/src/handlers/function_calling.rs index 295228b3..8f641df6 100644 --- a/crates/brightstaff/src/handlers/function_calling.rs +++ b/crates/brightstaff/src/handlers/function_calling.rs @@ -1,20 +1,18 @@ +use bytes::Bytes; +use eventsource_stream::Eventsource; +use futures::StreamExt; use hermesllm::apis::openai::{ ChatCompletionsRequest, ChatCompletionsResponse, Choice, FinishReason, FunctionCall, Message, MessageContent, ResponseMessage, Role, Tool, ToolCall, Usage, }; +use http_body_util::{combinators::BoxBody, BodyExt, Full}; +use hyper::body::Incoming; +use hyper::{Request, Response, StatusCode}; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use std::collections::HashMap; use thiserror::Error; -use tracing::{info, error}; -use futures::StreamExt; -use bytes::Bytes; -use http_body_util::{combinators::BoxBody, BodyExt, Full}; -use hyper::body::Incoming; -use hyper::{Request, Response, StatusCode}; -use eventsource_stream::Eventsource; - - +use tracing::{error, info}; // ============================================================================ // CONSTANTS FOR HALLUCINATION DETECTION @@ -273,17 +271,14 @@ impl ArchFunctionHandler { let mut stack: Vec = Vec::new(); let mut fixed_str = String::new(); - let matching_bracket: HashMap = - [(')', '('), ('}', '{'), (']', '[')] - .iter() - .cloned() - .collect(); - - let opening_bracket: HashMap = matching_bracket + let matching_bracket: HashMap = [(')', '('), ('}', '{'), (']', '[')] .iter() - .map(|(k, v)| (*v, *k)) + .cloned() .collect(); + let opening_bracket: HashMap = + matching_bracket.iter().map(|(k, v)| (*v, *k)).collect(); + for ch in json_str.chars() { if ch == '{' || ch == '[' || ch == '(' { stack.push(ch); @@ -332,12 +327,18 @@ impl ArchFunctionHandler { // Remove markdown code blocks let mut content = content.trim().to_string(); if content.starts_with("```") && content.ends_with("```") { - content = content.trim_start_matches("```").trim_end_matches("```").to_string(); + content = content + .trim_start_matches("```") + .trim_end_matches("```") + .to_string(); if content.starts_with("json") { content = content.trim_start_matches("json").to_string(); } // Trim again after removing code blocks to eliminate internal whitespace - content = content.trim_start_matches(r"\n").trim_end_matches(r"\n").to_string(); + content = content + .trim_start_matches(r"\n") + .trim_end_matches(r"\n") + .to_string(); content = content.trim().to_string(); // Unescape the quotes: \" -> " // The model sometimes returns escaped JSON inside markdown blocks @@ -453,12 +454,12 @@ impl ArchFunctionHandler { /// Helper method to check if a value matches the expected type fn check_value_type(&self, value: &Value, target_type: &str) -> bool { match target_type { - "int" | "integer" => value.is_i64() || value.is_u64(), + "int" | "integer" => value.is_i64() || value.is_u64(), "float" | "number" => value.is_f64() || value.is_i64() || value.is_u64(), - "bool" | "boolean" => value.is_boolean(), - "str" | "string" => value.is_string(), - "list" | "array" => value.is_array(), - "dict" | "object" => value.is_object(), + "bool" | "boolean" => value.is_boolean(), + "str" | "string" => value.is_string(), + "list" | "array" => value.is_array(), + "dict" | "object" => value.is_object(), _ => true, } } @@ -505,15 +506,19 @@ impl ArchFunctionHandler { let func_name = &tool_call.function.name; // Parse arguments as JSON - let func_args: HashMap = match serde_json::from_str(&tool_call.function.arguments) { - Ok(args) => args, - Err(e) => { - verification.is_valid = false; - verification.invalid_tool_call = Some(tool_call.clone()); - verification.error_message = format!("Failed to parse arguments for function '{}': {}", func_name, e); - break; - } - }; + let func_args: HashMap = + match serde_json::from_str(&tool_call.function.arguments) { + Ok(args) => args, + Err(e) => { + verification.is_valid = false; + verification.invalid_tool_call = Some(tool_call.clone()); + verification.error_message = format!( + "Failed to parse arguments for function '{}': {}", + func_name, e + ); + break; + } + }; // Check if function is available if let Some(function_params) = functions.get(func_name) { @@ -541,14 +546,23 @@ impl ArchFunctionHandler { if let Some(properties_obj) = properties.as_object() { for (param_name, param_value) in &func_args { if let Some(param_schema) = properties_obj.get(param_name) { - if let Some(target_type) = param_schema.get("type").and_then(|v| v.as_str()) { - if self.config.support_data_types.contains(&target_type.to_string()) { + if let Some(target_type) = + param_schema.get("type").and_then(|v| v.as_str()) + { + if self + .config + .support_data_types + .contains(&target_type.to_string()) + { // Validate data type using helper method - match self.validate_or_convert_parameter(param_value, target_type) { + match self + .validate_or_convert_parameter(param_value, target_type) + { Ok(is_valid) => { if !is_valid { verification.is_valid = false; - verification.invalid_tool_call = Some(tool_call.clone()); + verification.invalid_tool_call = + Some(tool_call.clone()); verification.error_message = format!( "Parameter `{}` is expected to have the data type `{}`, got incompatible type.", param_name, target_type @@ -558,7 +572,8 @@ impl ArchFunctionHandler { } Err(_) => { verification.is_valid = false; - verification.invalid_tool_call = Some(tool_call.clone()); + verification.invalid_tool_call = + Some(tool_call.clone()); verification.error_message = format!( "Parameter `{}` is expected to have the data type `{}`, got incompatible type.", param_name, target_type @@ -569,7 +584,10 @@ impl ArchFunctionHandler { } else { verification.is_valid = false; verification.invalid_tool_call = Some(tool_call.clone()); - verification.error_message = format!("Data type `{}` is not supported.", target_type); + verification.error_message = format!( + "Data type `{}` is not supported.", + target_type + ); break; } } @@ -598,11 +616,8 @@ impl ArchFunctionHandler { /// Formats the system prompt with tools pub fn format_system_prompt(&self, tools: &[Tool]) -> Result { let tools_str = self.convert_tools(tools)?; - let system_prompt = self - .config - .task_prompt - .replace("{tools}", &tools_str) - + &self.config.format_prompt; + let system_prompt = + self.config.task_prompt.replace("{tools}", &tools_str) + &self.config.format_prompt; Ok(system_prompt) } @@ -665,15 +680,22 @@ impl ArchFunctionHandler { // Strip markdown code blocks if tool_call_msg.starts_with("```") && tool_call_msg.ends_with("```") { - tool_call_msg = tool_call_msg.trim_start_matches("```").trim_end_matches("```").trim().to_string(); + tool_call_msg = tool_call_msg + .trim_start_matches("```") + .trim_end_matches("```") + .trim() + .to_string(); if tool_call_msg.starts_with("json") { - tool_call_msg = tool_call_msg.trim_start_matches("json").trim().to_string(); + tool_call_msg = + tool_call_msg.trim_start_matches("json").trim().to_string(); } } // Extract function name if let Ok(parsed) = serde_json::from_str::(&tool_call_msg) { - if let Some(tool_calls_arr) = parsed.get("tool_calls").and_then(|v| v.as_array()) { + if let Some(tool_calls_arr) = + parsed.get("tool_calls").and_then(|v| v.as_array()) + { if let Some(first_tool_call) = tool_calls_arr.first() { let func_name = first_tool_call .get("name") @@ -685,8 +707,10 @@ impl ArchFunctionHandler { "result": content, }); - content = format!("\n{}\n", - serde_json::to_string(&tool_response)?); + content = format!( + "\n{}\n", + serde_json::to_string(&tool_response)? + ); } } } @@ -717,7 +741,7 @@ impl ArchFunctionHandler { if let Some(instruction) = extra_instruction { if let Some(last) = processed_messages.last_mut() { if let MessageContent::Text(content) = &mut last.content { - content.push_str("\n"); + content.push('\n'); content.push_str(instruction); } } @@ -750,13 +774,11 @@ impl ArchFunctionHandler { for i in (conversation_idx..messages.len()).rev() { if let MessageContent::Text(content) = &messages[i].content { num_tokens += content.len() / 4; - if num_tokens >= max_tokens { - if messages[i].role == Role::User { - // Set message_idx to current position and break - // This matches Python's behavior where message_idx is set before break - message_idx = i; - break; - } + if num_tokens >= max_tokens && messages[i].role == Role::User { + // Set message_idx to current position and break + // This matches Python's behavior where message_idx is set before break + message_idx = i; + break; } } // Only update message_idx if we haven't hit the token limit yet @@ -789,7 +811,11 @@ impl ArchFunctionHandler { } /// Helper to create a request with VLLM-specific parameters - fn create_request_with_extra_body(&self, messages: Vec, stream: bool) -> ChatCompletionsRequest { + fn create_request_with_extra_body( + &self, + messages: Vec, + stream: bool, + ) -> ChatCompletionsRequest { ChatCompletionsRequest { model: self.model_name.clone(), messages, @@ -813,24 +839,38 @@ impl ArchFunctionHandler { } /// Makes a streaming request and returns the SSE event stream - async fn make_streaming_request(&self, request: ChatCompletionsRequest) -> Result> + Send>>> { - let request_body = serde_json::to_string(&request) - .map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Failed to serialize request: {}", e)))?; + async fn make_streaming_request( + &self, + request: ChatCompletionsRequest, + ) -> Result< + std::pin::Pin> + Send>>, + > { + let request_body = serde_json::to_string(&request).map_err(|e| { + FunctionCallingError::InvalidModelResponse(format!( + "Failed to serialize request: {}", + e + )) + })?; - let response = self.http_client + let response = self + .http_client .post(&self.endpoint_url) .header("Content-Type", "application/json") .body(request_body) .send() .await - .map_err(|e| FunctionCallingError::HttpError(e))?; + .map_err(FunctionCallingError::HttpError)?; if !response.status().is_success() { let status = response.status(); - let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string()); - return Err(FunctionCallingError::InvalidModelResponse( - format!("HTTP error {}: {}", status, error_text) - )); + let error_text = response + .text() + .await + .unwrap_or_else(|_| "Unknown error".to_string()); + return Err(FunctionCallingError::InvalidModelResponse(format!( + "HTTP error {}: {}", + status, error_text + ))); } // Parse SSE stream @@ -856,38 +896,51 @@ impl ArchFunctionHandler { } /// Makes a non-streaming request and returns the response - async fn make_non_streaming_request(&self, request: ChatCompletionsRequest) -> Result { - let request_body = serde_json::to_string(&request) - .map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Failed to serialize request: {}", e)))?; + async fn make_non_streaming_request( + &self, + request: ChatCompletionsRequest, + ) -> Result { + let request_body = serde_json::to_string(&request).map_err(|e| { + FunctionCallingError::InvalidModelResponse(format!( + "Failed to serialize request: {}", + e + )) + })?; - let response = self.http_client + let response = self + .http_client .post(&self.endpoint_url) .header("Content-Type", "application/json") .body(request_body) .send() .await - .map_err(|e| FunctionCallingError::HttpError(e))?; + .map_err(FunctionCallingError::HttpError)?; if !response.status().is_success() { let status = response.status(); - let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string()); - return Err(FunctionCallingError::InvalidModelResponse( - format!("HTTP error {}: {}", status, error_text) - )); + let error_text = response + .text() + .await + .unwrap_or_else(|_| "Unknown error".to_string()); + return Err(FunctionCallingError::InvalidModelResponse(format!( + "HTTP error {}: {}", + status, error_text + ))); } - let response_text = response.text().await - .map_err(|e| FunctionCallingError::HttpError(e))?; + let response_text = response + .text() + .await + .map_err(FunctionCallingError::HttpError)?; - serde_json::from_str(&response_text) - .map_err(|e| FunctionCallingError::JsonParseError(e)) + serde_json::from_str(&response_text).map_err(FunctionCallingError::JsonParseError) } pub async fn function_calling_chat( &self, request: ChatCompletionsRequest, ) -> Result { - use tracing::{info, error}; + use tracing::{error, info}; info!("[Arch-Function] - ChatCompletion"); @@ -899,10 +952,14 @@ impl ArchFunctionHandler { request.metadata.as_ref(), )?; - info!("[request to arch-fc]: model: {}, messages count: {}", - self.model_name, messages.len()); + info!( + "[request to arch-fc]: model: {}, messages count: {}", + self.model_name, + messages.len() + ); - let use_agent_orchestrator = request.metadata + let use_agent_orchestrator = request + .metadata .as_ref() .and_then(|m| m.get("use_agent_orchestrator")) .and_then(|v| v.as_bool()) @@ -918,89 +975,95 @@ impl ArchFunctionHandler { if use_agent_orchestrator { while let Some(chunk_result) = stream.next().await { - let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?; + let chunk = chunk_result.map_err(FunctionCallingError::InvalidModelResponse)?; // Extract content from JSON response if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) { if let Some(choice) = choices.first() { - if let Some(content) = choice.get("delta") + if let Some(content) = choice + .get("delta") .and_then(|d| d.get("content")) - .and_then(|c| c.as_str()) { + .and_then(|c| c.as_str()) + { model_response.push_str(content); } } } } info!("[Agent Orchestrator]: response received"); - } else { - if let Some(tools) = request.tools.as_ref() { - let mut hallucination_state = HallucinationState::new(tools); - let mut has_tool_calls = None; - let mut has_hallucination = false; + } else if let Some(tools) = request.tools.as_ref() { + let mut hallucination_state = HallucinationState::new(tools); + let mut has_tool_calls = None; + let mut has_hallucination = false; - while let Some(chunk_result) = stream.next().await { - let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?; + while let Some(chunk_result) = stream.next().await { + let chunk = chunk_result.map_err(FunctionCallingError::InvalidModelResponse)?; - // Extract content and logprobs from JSON response - if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) { - if let Some(choice) = choices.first() { - if let Some(content) = choice.get("delta") - .and_then(|d| d.get("content")) - .and_then(|c| c.as_str()) { + // Extract content and logprobs from JSON response + if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) { + if let Some(choice) = choices.first() { + if let Some(content) = choice + .get("delta") + .and_then(|d| d.get("content")) + .and_then(|c| c.as_str()) + { + // Extract logprobs + let logprobs: Vec = choice + .get("logprobs") + .and_then(|lp| lp.get("content")) + .and_then(|c| c.as_array()) + .and_then(|arr| arr.first()) + .and_then(|token| token.get("top_logprobs")) + .and_then(|tlp| tlp.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|v| v.get("logprob").and_then(|lp| lp.as_f64())) + .collect() + }) + .unwrap_or_default(); - // Extract logprobs - let logprobs: Vec = choice.get("logprobs") - .and_then(|lp| lp.get("content")) - .and_then(|c| c.as_array()) - .and_then(|arr| arr.first()) - .and_then(|token| token.get("top_logprobs")) - .and_then(|tlp| tlp.as_array()) - .map(|arr| { - arr.iter() - .filter_map(|v| v.get("logprob").and_then(|lp| lp.as_f64())) - .collect() - }) - .unwrap_or_default(); + if hallucination_state + .append_and_check_token_hallucination(content.to_string(), logprobs) + { + has_hallucination = true; + break; + } - if hallucination_state.append_and_check_token_hallucination(content.to_string(), logprobs) { - has_hallucination = true; - break; - } - - if hallucination_state.tokens.len() > 5 && has_tool_calls.is_none() { - let collected_content = hallucination_state.tokens.join(""); - has_tool_calls = Some(collected_content.contains("tool_calls")); - } + if hallucination_state.tokens.len() > 5 && has_tool_calls.is_none() { + let collected_content = hallucination_state.tokens.join(""); + has_tool_calls = Some(collected_content.contains("tool_calls")); } } } } + } - if has_tool_calls == Some(true) && has_hallucination { - info!("[Hallucination]: {}", hallucination_state.error_message); + if has_tool_calls == Some(true) && has_hallucination { + info!("[Hallucination]: {}", hallucination_state.error_message); - let clarify_messages = self.prefill_message(messages.clone(), &self.clarify_prefix); - let clarify_request = self.create_request_with_extra_body(clarify_messages, false); + let clarify_messages = self.prefill_message(messages.clone(), &self.clarify_prefix); + let clarify_request = self.create_request_with_extra_body(clarify_messages, false); - let retry_response = self.make_non_streaming_request(clarify_request).await?; + let retry_response = self.make_non_streaming_request(clarify_request).await?; - if let Some(choice) = retry_response.choices.first() { - if let Some(content) = &choice.message.content { - model_response = content.clone(); - } + if let Some(choice) = retry_response.choices.first() { + if let Some(content) = &choice.message.content { + model_response = content.clone(); } - } else { - model_response = hallucination_state.tokens.join(""); } } else { - while let Some(chunk_result) = stream.next().await { - let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?; - if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) { - if let Some(choice) = choices.first() { - if let Some(content) = choice.get("delta") - .and_then(|d| d.get("content")) - .and_then(|c| c.as_str()) { - model_response.push_str(content); - } + model_response = hallucination_state.tokens.join(""); + } + } else { + while let Some(chunk_result) = stream.next().await { + let chunk = chunk_result.map_err(FunctionCallingError::InvalidModelResponse)?; + if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) { + if let Some(choice) = choices.first() { + if let Some(content) = choice + .get("delta") + .and_then(|d| d.get("content")) + .and_then(|c| c.as_str()) + { + model_response.push_str(content); } } } @@ -1009,10 +1072,17 @@ impl ArchFunctionHandler { let response_dict = self.parse_model_response(&model_response); - info!("[arch-fc]: raw model response: {}", response_dict.raw_response); + info!( + "[arch-fc]: raw model response: {}", + response_dict.raw_response + ); // General model response (no intent matched - should route to default target) - let model_message = if response_dict.response.as_ref().map_or(false, |s| !s.is_empty()) { + let model_message = if response_dict + .response + .as_ref() + .is_some_and(|s| !s.is_empty()) + { // When arch-fc returns a "response" field, it means no intent was matched // Return empty content and empty tool_calls so prompt_gateway routes to default target ResponseMessage { @@ -1053,8 +1123,11 @@ impl ArchFunctionHandler { let verification = self.verify_tool_calls(tools, &response_dict.tool_calls); if verification.is_valid { - info!("[Tool calls]: {:?}", - response_dict.tool_calls.iter() + info!( + "[Tool calls]: {:?}", + response_dict + .tool_calls + .iter() .map(|tc| &tc.function) .collect::>() ); @@ -1092,8 +1165,11 @@ impl ArchFunctionHandler { } } } else { - info!("[Tool calls]: {:?}", - response_dict.tool_calls.iter() + info!( + "[Tool calls]: {:?}", + response_dict + .tool_calls + .iter() .map(|tc| &tc.function) .collect::>() ); @@ -1108,7 +1184,10 @@ impl ArchFunctionHandler { } } } else { - error!("Invalid tool calls in response: {}", response_dict.error_message); + error!( + "Invalid tool calls in response: {}", + response_dict.error_message + ); ResponseMessage { role: Role::Assistant, content: Some(String::new()), @@ -1243,7 +1322,6 @@ pub async fn function_calling_chat_handler( req: Request, llm_provider_url: String, ) -> std::result::Result>, hyper::Error> { - use hermesllm::apis::openai::ChatCompletionsRequest; let whole_body = req.collect().await?.to_bytes(); @@ -1255,10 +1333,13 @@ pub async fn function_calling_chat_handler( let mut response = Response::new(full( serde_json::json!({ "error": format!("Invalid request body: {}", e) - }).to_string() + }) + .to_string(), )); *response.status_mut() = StatusCode::BAD_REQUEST; - response.headers_mut().insert("Content-Type", "application/json".parse().unwrap()); + response + .headers_mut() + .insert("Content-Type", "application/json".parse().unwrap()); return Ok(response); } }; @@ -1271,24 +1352,31 @@ pub async fn function_calling_chat_handler( // Parse as ChatCompletionsRequest let chat_request: ChatCompletionsRequest = match serde_json::from_value(body_json) { Ok(req) => { - info!("[request body]: {}", serde_json::to_string(&req).unwrap_or_default()); + info!( + "[request body]: {}", + serde_json::to_string(&req).unwrap_or_default() + ); req - }, + } Err(e) => { error!("Failed to parse request body: {}", e); let mut response = Response::new(full( serde_json::json!({ "error": format!("Invalid request body: {}", e) - }).to_string() + }) + .to_string(), )); *response.status_mut() = StatusCode::BAD_REQUEST; - response.headers_mut().insert("Content-Type", "application/json".parse().unwrap()); + response + .headers_mut() + .insert("Content-Type", "application/json".parse().unwrap()); return Ok(response); } }; // Determine which handler to use based on metadata - let use_agent_orchestrator = chat_request.metadata + let use_agent_orchestrator = chat_request + .metadata .as_ref() .and_then(|m| m.get("use_agent_orchestrator")) .and_then(|v| v.as_bool()) @@ -1309,7 +1397,10 @@ pub async fn function_calling_chat_handler( ARCH_FUNCTION_MODEL_NAME.to_string(), llm_provider_url.clone(), ); - handler.function_handler.function_calling_chat(chat_request).await + handler + .function_handler + .function_calling_chat(chat_request) + .await } else { let handler = ArchFunctionHandler::new( ARCH_FUNCTION_MODEL_NAME.to_string(), @@ -1328,7 +1419,9 @@ pub async fn function_calling_chat_handler( let mut response = Response::new(full(response_json)); *response.status_mut() = StatusCode::OK; - response.headers_mut().insert("Content-Type", "application/json".parse().unwrap()); + response + .headers_mut() + .insert("Content-Type", "application/json".parse().unwrap()); Ok(response) } @@ -1341,13 +1434,14 @@ pub async fn function_calling_chat_handler( let mut response = Response::new(full(error_response.to_string())); *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; - response.headers_mut().insert("Content-Type", "application/json".parse().unwrap()); + response + .headers_mut() + .insert("Content-Type", "application/json".parse().unwrap()); Ok(response) } } } - // ============================================================================ // TESTS // ============================================================================ @@ -1370,10 +1464,13 @@ mod tests { assert!(config.task_prompt.contains("\\n\\n")); // Format prompt should contain literal escaped newlines and proper JSON examples - assert!(config.format_prompt.contains("\\n\\nBased on your analysis")); - assert!(config.format_prompt.contains(r#"{\"response\": \"Your response text here\"}"#)); + assert!(config + .format_prompt + .contains("\\n\\nBased on your analysis")); + assert!(config + .format_prompt + .contains(r#"{\"response\": \"Your response text here\"}"#)); assert!(config.format_prompt.contains(r#"{\"tool_calls\": [{"#)); - } #[test] @@ -1384,7 +1481,11 @@ mod tests { #[test] fn test_fix_json_string_valid() { - let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string()); + let handler = ArchFunctionHandler::new( + "test-model".to_string(), + ArchFunctionConfig::default(), + "http://localhost:8000".to_string(), + ); let json_str = r#"{"name": "test", "value": 123}"#; let result = handler.fix_json_string(json_str); assert!(result.is_ok()); @@ -1392,7 +1493,11 @@ mod tests { #[test] fn test_fix_json_string_missing_bracket() { - let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string()); + let handler = ArchFunctionHandler::new( + "test-model".to_string(), + ArchFunctionConfig::default(), + "http://localhost:8000".to_string(), + ); let json_str = r#"{"name": "test", "value": 123"#; let result = handler.fix_json_string(json_str); assert!(result.is_ok()); @@ -1402,8 +1507,13 @@ mod tests { #[test] fn test_parse_model_response_with_tool_calls() { - let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string()); - let content = r#"{"tool_calls": [{"name": "get_weather", "arguments": {"location": "NYC"}}]}"#; + let handler = ArchFunctionHandler::new( + "test-model".to_string(), + ArchFunctionConfig::default(), + "http://localhost:8000".to_string(), + ); + let content = + r#"{"tool_calls": [{"name": "get_weather", "arguments": {"location": "NYC"}}]}"#; let result = handler.parse_model_response(content); assert!(result.is_valid); @@ -1413,8 +1523,13 @@ mod tests { #[test] fn test_parse_model_response_with_clarification() { - let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string()); - let content = r#"{"required_functions": ["get_weather"], "clarification": "What location?"}"#; + let handler = ArchFunctionHandler::new( + "test-model".to_string(), + ArchFunctionConfig::default(), + "http://localhost:8000".to_string(), + ); + let content = + r#"{"required_functions": ["get_weather"], "clarification": "What location?"}"#; let result = handler.parse_model_response(content); assert!(result.is_valid); @@ -1424,7 +1539,11 @@ mod tests { #[test] fn test_convert_data_type_int_to_float() { - let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string()); + let handler = ArchFunctionHandler::new( + "test-model".to_string(), + ArchFunctionConfig::default(), + "http://localhost:8000".to_string(), + ); let value = json!(42); let result = handler.convert_data_type(&value, "float"); assert!(result.is_ok()); @@ -1504,13 +1623,12 @@ pub fn check_threshold( } /// Checks if a parameter is required in the function description -pub fn is_parameter_required( - function_description: &Value, - parameter_name: &str, -) -> bool { +pub fn is_parameter_required(function_description: &Value, parameter_name: &str) -> bool { if let Some(required) = function_description.get("required") { if let Some(required_arr) = required.as_array() { - return required_arr.iter().any(|v| v.as_str() == Some(parameter_name)); + return required_arr + .iter() + .any(|v| v.as_str() == Some(parameter_name)); } } false @@ -1559,12 +1677,7 @@ impl HallucinationState { pub fn new(functions: &[Tool]) -> Self { let function_properties: HashMap = functions .iter() - .map(|tool| { - ( - tool.function.name.clone(), - tool.function.parameters.clone(), - ) - }) + .map(|tool| (tool.function.name.clone(), tool.function.parameters.clone())) .collect(); Self { @@ -1620,7 +1733,10 @@ impl HallucinationState { // Function name extraction logic if self.state.as_deref() == Some("function_name") { - if !FUNC_NAME_END_TOKEN.iter().any(|&t| self.tokens.last().map_or(false, |tok| tok == t)) { + if !FUNC_NAME_END_TOKEN + .iter() + .any(|&t| self.tokens.last().is_some_and(|tok| tok == t)) + { self.mask.push(MaskToken::FunctionName); } else { self.state = None; @@ -1629,34 +1745,51 @@ impl HallucinationState { } // Check for function name start - if FUNC_NAME_START_PATTERN.iter().any(|&p| content.ends_with(p)) { + if FUNC_NAME_START_PATTERN + .iter() + .any(|&p| content.ends_with(p)) + { self.state = Some("function_name".to_string()); } // Parameter name extraction logic if self.state.as_deref() == Some("parameter_name") - && !PARAMETER_NAME_END_TOKENS.iter().any(|&t| content.ends_with(t)) { + && !PARAMETER_NAME_END_TOKENS + .iter() + .any(|&t| content.ends_with(t)) + { self.mask.push(MaskToken::ParameterName); } else if self.state.as_deref() == Some("parameter_name") - && PARAMETER_NAME_END_TOKENS.iter().any(|&t| content.ends_with(t)) { + && PARAMETER_NAME_END_TOKENS + .iter() + .any(|&t| content.ends_with(t)) + { self.state = None; self.parameter_name_done = true; self.get_parameter_name(); } else if self.parameter_name_done && !self.open_bracket - && PARAMETER_NAME_START_PATTERN.iter().any(|&p| content.ends_with(p)) { + && PARAMETER_NAME_START_PATTERN + .iter() + .any(|&p| content.ends_with(p)) + { self.state = Some("parameter_name".to_string()); } // First parameter value start - if FIRST_PARAM_NAME_START_PATTERN.iter().any(|&p| content.ends_with(p)) { + if FIRST_PARAM_NAME_START_PATTERN + .iter() + .any(|&p| content.ends_with(p)) + { self.state = Some("parameter_name".to_string()); } // Parameter value extraction logic if self.state.as_deref() == Some("parameter_value") - && !PARAMETER_VALUE_END_TOKEN.iter().any(|&t| content.ends_with(t)) { - + && !PARAMETER_VALUE_END_TOKEN + .iter() + .any(|&t| content.ends_with(t)) + { // Check for brackets if let Some(last_token) = self.tokens.last() { let open_brackets: Vec = last_token @@ -1694,8 +1827,11 @@ impl HallucinationState { && self.mask[self.mask.len() - 2] != MaskToken::ParameterValue && !self.parameter_name.is_empty() { - let last_param = self.parameter_name[self.parameter_name.len() - 1].clone(); - if let Some(func_props) = self.function_properties.get(&self.function_name) { + let last_param = + self.parameter_name[self.parameter_name.len() - 1].clone(); + if let Some(func_props) = + self.function_properties.get(&self.function_name) + { if is_parameter_required(func_props, &last_param) && !is_parameter_property(func_props, &last_param, "enum") && !self.check_parameter_name.contains_key(&last_param) @@ -1718,10 +1854,16 @@ impl HallucinationState { } } else if self.state.as_deref() == Some("parameter_value") && !self.open_bracket - && PARAMETER_VALUE_END_TOKEN.iter().any(|&t| content.ends_with(t)) { + && PARAMETER_VALUE_END_TOKEN + .iter() + .any(|&t| content.ends_with(t)) + { self.state = None; } else if self.parameter_name_done - && PARAMETER_VALUE_START_PATTERN.iter().any(|&p| content.ends_with(p)) { + && PARAMETER_VALUE_START_PATTERN + .iter() + .any(|&p| content.ends_with(p)) + { self.state = Some("parameter_value".to_string()); } @@ -1848,18 +1990,18 @@ mod hallucination_tests { let handler = ArchFunctionHandler::new( "test-model".to_string(), ArchFunctionConfig::default(), - "http://localhost:8000".to_string() + "http://localhost:8000".to_string(), ); // Test integer types assert!(handler.check_value_type(&json!(42), "integer")); assert!(handler.check_value_type(&json!(42), "int")); - assert!(!handler.check_value_type(&json!(3.14), "integer")); + assert!(!handler.check_value_type(&json!(3.15), "integer")); // Test number types (accepts both int and float) - assert!(handler.check_value_type(&json!(3.14), "number")); + assert!(handler.check_value_type(&json!(3.15), "number")); assert!(handler.check_value_type(&json!(42), "number")); - assert!(handler.check_value_type(&json!(3.14), "float")); + assert!(handler.check_value_type(&json!(3.15), "float")); // Test boolean assert!(handler.check_value_type(&json!(true), "boolean")); @@ -1890,12 +2032,16 @@ mod hallucination_tests { let handler = ArchFunctionHandler::new( "test-model".to_string(), ArchFunctionConfig::default(), - "http://localhost:8000".to_string() + "http://localhost:8000".to_string(), ); // Test valid type - no conversion needed - assert!(handler.validate_or_convert_parameter(&json!(42), "integer").unwrap()); - assert!(handler.validate_or_convert_parameter(&json!("hello"), "string").unwrap()); + assert!(handler + .validate_or_convert_parameter(&json!(42), "integer") + .unwrap()); + assert!(handler + .validate_or_convert_parameter(&json!("hello"), "string") + .unwrap()); // Test integer to float conversion (convert_data_type supports this) let result = handler.validate_or_convert_parameter(&json!(42), "float"); @@ -1910,8 +2056,12 @@ mod hallucination_tests { assert!(!result.unwrap()); // Test number accepting both int and float - assert!(handler.validate_or_convert_parameter(&json!(42), "number").unwrap()); - assert!(handler.validate_or_convert_parameter(&json!(3.14), "number").unwrap()); + assert!(handler + .validate_or_convert_parameter(&json!(42), "number") + .unwrap()); + assert!(handler + .validate_or_convert_parameter(&json!(3.15), "number") + .unwrap()); } #[test] diff --git a/crates/brightstaff/src/handlers/integration_tests.rs b/crates/brightstaff/src/handlers/integration_tests.rs index 01ea1574..29552f83 100644 --- a/crates/brightstaff/src/handlers/integration_tests.rs +++ b/crates/brightstaff/src/handlers/integration_tests.rs @@ -14,7 +14,7 @@ use crate::router::plano_orchestrator::OrchestratorService; /// 2. PipelineProcessor - executes the agent pipeline /// 3. ResponseHandler - handles response streaming #[cfg(test)] -mod integration_tests { +mod tests { use super::*; use common::configuration::{Agent, AgentFilterChain, Listener}; @@ -62,7 +62,10 @@ mod integration_tests { let agent_pipeline = AgentFilterChain { id: "terminal-agent".to_string(), - filter_chain: Some(vec!["filter-agent".to_string(), "terminal-agent".to_string()]), + filter_chain: Some(vec![ + "filter-agent".to_string(), + "terminal-agent".to_string(), + ]), description: Some("Test pipeline".to_string()), default: Some(true), }; diff --git a/crates/brightstaff/src/handlers/jsonrpc.rs b/crates/brightstaff/src/handlers/jsonrpc.rs index a34167fe..4c8f6214 100644 --- a/crates/brightstaff/src/handlers/jsonrpc.rs +++ b/crates/brightstaff/src/handlers/jsonrpc.rs @@ -2,48 +2,48 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; pub const JSON_RPC_VERSION: &str = "2.0"; -pub const TOOL_CALL_METHOD : &str = "tools/call"; +pub const TOOL_CALL_METHOD: &str = "tools/call"; pub const MCP_INITIALIZE: &str = "initialize"; pub const MCP_INITIALIZE_NOTIFICATION: &str = "notifications/initialized"; #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(untagged)] pub enum JsonRpcId { - String(String), - Number(u64), + String(String), + Number(u64), } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct JsonRpcRequest { - pub jsonrpc: String, - pub id: JsonRpcId, - pub method: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub params: Option>, + pub jsonrpc: String, + pub id: JsonRpcId, + pub method: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub params: Option>, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct JsonRpcNotification { - pub jsonrpc: String, - pub method: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub params: Option>, + pub jsonrpc: String, + pub method: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub params: Option>, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct JsonRpcError { - pub code: i32, - pub message: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub data: Option, + pub code: i32, + pub message: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct JsonRpcResponse { - pub jsonrpc: String, - pub id: JsonRpcId, - #[serde(skip_serializing_if = "Option::is_none")] - pub result: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub error: Option, + pub jsonrpc: String, + pub id: JsonRpcId, + #[serde(skip_serializing_if = "Option::is_none")] + pub result: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, } diff --git a/crates/brightstaff/src/handlers/llm.rs b/crates/brightstaff/src/handlers/llm.rs index 5c5bcf01..b311976a 100644 --- a/crates/brightstaff/src/handlers/llm.rs +++ b/crates/brightstaff/src/handlers/llm.rs @@ -1,6 +1,8 @@ use bytes::Bytes; use common::configuration::{LlmProvider, ModelAlias}; -use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER}; +use common::consts::{ + ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER, +}; use common::traces::TraceCollector; use hermesllm::apis::openai_responses::InputParam; use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs}; @@ -14,13 +16,14 @@ use std::sync::Arc; use tokio::sync::RwLock; use tracing::{debug, info, warn}; -use crate::router::llm_router::RouterService; -use crate::handlers::utils::{create_streaming_response, ObservableStreamProcessor, truncate_message}; use crate::handlers::router_chat::router_chat_get_upstream_model; +use crate::handlers::utils::{ + create_streaming_response, truncate_message, ObservableStreamProcessor, +}; +use crate::router::llm_router::RouterService; use crate::state::response_state_processor::ResponsesStateProcessor; use crate::state::{ - StateStorage, StateStorageError, - extract_input_items, retrieve_and_combine_input + extract_input_items, retrieve_and_combine_input, StateStorage, StateStorageError, }; use crate::tracing::operation_component; @@ -39,7 +42,6 @@ pub async fn llm_chat( trace_collector: Arc, state_storage: Option>, ) -> Result>, hyper::Error> { - let request_path = request.uri().path().to_string(); let request_headers = request.headers().clone(); let request_id = request_headers @@ -74,8 +76,14 @@ pub async fn llm_chat( )) { Ok(request) => request, Err(err) => { - warn!("[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request as ProviderRequestType: {}", request_id, err); - let err_msg = format!("[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request: {}", request_id, err); + warn!( + "[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request as ProviderRequestType: {}", + request_id, err + ); + let err_msg = format!( + "[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request: {}", + request_id, err + ); let mut bad_request = Response::new(full(err_msg)); *bad_request.status_mut() = StatusCode::BAD_REQUEST; return Ok(bad_request); @@ -85,7 +93,10 @@ pub async fn llm_chat( // === v1/responses state management: Extract input items early === let mut original_input_items = Vec::new(); let client_api = SupportedAPIsFromClient::from_endpoint(request_path.as_str()); - let is_responses_api_client = matches!(client_api, Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_))); + let is_responses_api_client = matches!( + client_api, + Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_)) + ); // Model alias resolution: update model field in client_request immediately // This ensures all downstream objects use the resolved model @@ -96,20 +107,28 @@ pub async fn llm_chat( // Extract tool names and user message preview for span attributes let tool_names = client_request.get_tool_names(); - let user_message_preview = client_request.get_recent_user_message() + let user_message_preview = client_request + .get_recent_user_message() .map(|msg| truncate_message(&msg, 50)); client_request.set_model(resolved_model.clone()); if client_request.remove_metadata_key("archgw_preference_config") { - debug!("[PLANO_REQ_ID:{}] Removed archgw_preference_config from metadata", request_id); + debug!( + "[PLANO_REQ_ID:{}] Removed archgw_preference_config from metadata", + request_id + ); } // === v1/responses state management: Determine upstream API and combine input if needed === // Do this BEFORE routing since routing consumes the request // Only process state if state_storage is configured let mut should_manage_state = false; - if is_responses_api_client && state_storage.is_some() { - if let ProviderRequestType::ResponsesAPIRequest(ref mut responses_req) = client_request { + if is_responses_api_client { + if let ( + ProviderRequestType::ResponsesAPIRequest(ref mut responses_req), + Some(ref state_store), + ) = (&mut client_request, &state_storage) + { // Extract original input once original_input_items = extract_input_items(&responses_req.input); @@ -120,18 +139,22 @@ pub async fn llm_chat( &request_path, &resolved_model, is_streaming_request, - ).await; + ) + .await; let upstream_api = SupportedUpstreamAPIs::from_endpoint(&upstream_path); // Only manage state if upstream is NOT OpenAIResponsesAPI (needs translation) - should_manage_state = !matches!(upstream_api, Some(SupportedUpstreamAPIs::OpenAIResponsesAPI(_))); + should_manage_state = !matches!( + upstream_api, + Some(SupportedUpstreamAPIs::OpenAIResponsesAPI(_)) + ); if should_manage_state { // Retrieve and combine conversation history if previous_response_id exists if let Some(ref prev_resp_id) = responses_req.previous_response_id { match retrieve_and_combine_input( - state_storage.as_ref().unwrap().clone(), + state_store.clone(), prev_resp_id, original_input_items, // Pass ownership instead of cloning ) @@ -166,7 +189,10 @@ pub async fn llm_chat( } } } else { - debug!("[PLANO_REQ_ID:{}] | BRIGHT_STAFF | Upstream supports ResponsesAPI natively.", request_id); + debug!( + "[PLANO_REQ_ID:{}] | BRIGHT_STAFF | Upstream supports ResponsesAPI natively.", + request_id + ); } } } @@ -177,7 +203,7 @@ pub async fn llm_chat( // Determine routing using the dedicated router_chat module let routing_result = match router_chat_get_upstream_model( router_service, - client_request, // Pass the original request - router_chat will convert it + client_request, // Pass the original request - router_chat will convert it &request_headers, trace_collector.clone(), &traceparent, @@ -257,7 +283,8 @@ pub async fn llm_chat( user_message_preview, temperature, &llm_providers, - ).await; + ) + .await; // Create base processor for metrics and tracing let base_processor = ObservableStreamProcessor::new( @@ -269,7 +296,11 @@ pub async fn llm_chat( // === v1/responses state management: Wrap with ResponsesStateProcessor === // Only wrap if we need to manage state (client is ResponsesAPI AND upstream is NOT ResponsesAPI AND state_storage is configured) - let streaming_response = if should_manage_state && !original_input_items.is_empty() && state_storage.is_some() { + let streaming_response = if let (true, false, Some(state_store)) = ( + should_manage_state, + original_input_items.is_empty(), + state_storage, + ) { // Extract Content-Encoding header to handle decompression for state parsing let content_encoding = response_headers .get("content-encoding") @@ -279,7 +310,7 @@ pub async fn llm_chat( // Wrap with state management processor to store state after response completes let state_processor = ResponsesStateProcessor::new( base_processor, - state_storage.unwrap(), + state_store, original_input_items, resolved_model.clone(), model_name.clone(), @@ -324,6 +355,7 @@ fn resolve_model_alias( } /// Builds the LLM span with all required and optional attributes. +#[allow(clippy::too_many_arguments)] async fn build_llm_span( traceparent: &str, request_path: &str, @@ -337,8 +369,8 @@ async fn build_llm_span( temperature: Option, llm_providers: &Arc>>, ) -> common::traces::Span { - use common::traces::{SpanBuilder, SpanKind, parse_traceparent}; use crate::tracing::{http, llm, OperationNameBuilder}; + use common::traces::{parse_traceparent, SpanBuilder, SpanKind}; // Calculate the upstream path based on provider configuration let upstream_path = get_upstream_path( @@ -347,13 +379,14 @@ async fn build_llm_span( request_path, resolved_model, is_streaming, - ).await; + ) + .await; // Build operation name showing path transformation if different let operation_name = if request_path != upstream_path { OperationNameBuilder::new() .with_method("POST") - .with_path(&format!("{} >> {}", request_path, upstream_path)) + .with_path(format!("{} >> {}", request_path, upstream_path)) .with_target(resolved_model) .build() } else { @@ -388,7 +421,8 @@ async fn build_llm_span( } if let Some(tools) = tool_names { - let formatted_tools = tools.iter() + let formatted_tools = tools + .iter() .map(|name| format!("{}(...)", name)) .collect::>() .join("\n"); @@ -436,8 +470,7 @@ async fn get_provider_info( // First, try to find by model name or provider name let provider = providers_lock.iter().find(|p| { - p.model.as_ref().map(|m| m == model_name).unwrap_or(false) - || p.name == model_name + p.model.as_ref().map(|m| m == model_name).unwrap_or(false) || p.name == model_name }); if let Some(provider) = provider { @@ -446,9 +479,7 @@ async fn get_provider_info( return (provider_id, prefix); } - let default_provider = providers_lock.iter().find(|p| { - p.default.unwrap_or(false) - }); + let default_provider = providers_lock.iter().find(|p| p.default.unwrap_or(false)); if let Some(provider) = default_provider { let provider_id = provider.provider_interface.to_provider_id(); diff --git a/crates/brightstaff/src/handlers/mod.rs b/crates/brightstaff/src/handlers/mod.rs index e63958d5..0bbd3454 100644 --- a/crates/brightstaff/src/handlers/mod.rs +++ b/crates/brightstaff/src/handlers/mod.rs @@ -1,13 +1,13 @@ pub mod agent_chat_completions; pub mod agent_selector; -pub mod llm; -pub mod router_chat; -pub mod models; pub mod function_calling; +pub mod jsonrpc; +pub mod llm; +pub mod models; pub mod pipeline_processor; pub mod response_handler; +pub mod router_chat; pub mod utils; -pub mod jsonrpc; #[cfg(test)] mod integration_tests; diff --git a/crates/brightstaff/src/handlers/pipeline_processor.rs b/crates/brightstaff/src/handlers/pipeline_processor.rs index a40279c3..2c1d9859 100644 --- a/crates/brightstaff/src/handlers/pipeline_processor.rs +++ b/crates/brightstaff/src/handlers/pipeline_processor.rs @@ -82,6 +82,7 @@ impl PipelineProcessor { } /// Record a span for filter execution + #[allow(clippy::too_many_arguments)] fn record_filter_span( &self, collector: &std::sync::Arc, @@ -132,6 +133,7 @@ impl PipelineProcessor { } /// Record a span for MCP protocol interactions + #[allow(clippy::too_many_arguments)] fn record_agent_filter_span( &self, collector: &std::sync::Arc, @@ -156,12 +158,12 @@ impl PipelineProcessor { .build(); let mut span_builder = SpanBuilder::new(&operation_name) - .with_span_id(span_id.unwrap_or_else(|| generate_random_span_id())) + .with_span_id(span_id.unwrap_or_else(generate_random_span_id)) .with_kind(SpanKind::Client) .with_start_time(start_time) .with_end_time(end_time) .with_attribute(http::METHOD, "POST") - .with_attribute(http::TARGET, &format!("/mcp ({})", operation.to_string())) + .with_attribute(http::TARGET, format!("/mcp ({})", operation)) .with_attribute("mcp.operation", operation.to_string()) .with_attribute("mcp.agent_id", agent_id.to_string()) .with_attribute( @@ -188,6 +190,7 @@ impl PipelineProcessor { } /// Process the filter chain of agents (all except the terminal agent) + #[allow(clippy::too_many_arguments)] pub async fn process_filter_chain( &mut self, chat_history: &[Message], @@ -1023,7 +1026,7 @@ mod tests { } }); - let sse_body = format!("event: message\ndata: {}\n\n", rpc_body.to_string()); + let sse_body = format!("event: message\ndata: {}\n\n", rpc_body); let mut server = Server::new_async().await; let _m = server @@ -1061,10 +1064,10 @@ mod tests { .await; match result { - Err(PipelineError::ClientError { status, body, .. }) => { - assert_eq!(status, 400); - assert_eq!(body, "bad tool call"); - } + Err(PipelineError::ClientError { status, body, .. }) => { + assert_eq!(status, 400); + assert_eq!(body, "bad tool call"); + } _ => panic!("Expected client error when isError flag is set"), } } diff --git a/crates/brightstaff/src/handlers/response_handler.rs b/crates/brightstaff/src/handlers/response_handler.rs index d8b3bedf..d386df1e 100644 --- a/crates/brightstaff/src/handlers/response_handler.rs +++ b/crates/brightstaff/src/handlers/response_handler.rs @@ -133,9 +133,7 @@ impl ResponseHandler { let response_headers = llm_response.headers(); let is_sse_streaming = response_headers .get(hyper::header::CONTENT_TYPE) - .map_or(false, |v| { - v.to_str().unwrap_or("").contains("text/event-stream") - }); + .is_some_and(|v| v.to_str().unwrap_or("").contains("text/event-stream")); let response_bytes = llm_response .bytes() @@ -164,7 +162,7 @@ impl ResponseHandler { match transformed_event.provider_response() { Ok(provider_response) => { if let Some(content) = provider_response.content_delta() { - accumulated_text.push_str(&content); + accumulated_text.push_str(content); } else { info!("No content delta in provider response"); } @@ -174,7 +172,7 @@ impl ResponseHandler { } } } - return Ok(accumulated_text); + Ok(accumulated_text) } else { // If not SSE, treat as regular text response let response_text = String::from_utf8(response_bytes.to_vec()).map_err(|e| { diff --git a/crates/brightstaff/src/handlers/router_chat.rs b/crates/brightstaff/src/handlers/router_chat.rs index a927a0eb..ed0d6d31 100644 --- a/crates/brightstaff/src/handlers/router_chat.rs +++ b/crates/brightstaff/src/handlers/router_chat.rs @@ -1,6 +1,6 @@ use common::configuration::ModelUsagePreference; -use common::consts::{REQUEST_ID_HEADER}; -use common::traces::{TraceCollector, SpanKind, SpanBuilder, parse_traceparent}; +use common::consts::REQUEST_ID_HEADER; +use common::traces::{parse_traceparent, SpanBuilder, SpanKind, TraceCollector}; use hermesllm::clients::endpoints::SupportedUpstreamAPIs; use hermesllm::{ProviderRequest, ProviderRequestType}; use hyper::StatusCode; @@ -9,10 +9,10 @@ use std::sync::Arc; use tracing::{debug, info, warn}; use crate::router::llm_router::RouterService; -use crate::tracing::{OperationNameBuilder, operation_component, http, routing}; +use crate::tracing::{http, operation_component, routing, OperationNameBuilder}; pub struct RoutingResult { - pub model_name: String + pub model_name: String, } pub struct RoutingError { @@ -24,7 +24,7 @@ impl RoutingError { pub fn internal_error(message: String) -> Self { Self { message, - status_code: StatusCode::INTERNAL_SERVER_ERROR + status_code: StatusCode::INTERNAL_SERVER_ERROR, } } } @@ -52,9 +52,7 @@ pub async fn router_chat_get_upstream_model( // Convert to ChatCompletionsRequest for routing (regardless of input type) let chat_request = match ProviderRequestType::try_from(( client_request, - &SupportedUpstreamAPIs::OpenAIChatCompletions( - hermesllm::apis::OpenAIApi::ChatCompletions, - ), + &SupportedUpstreamAPIs::OpenAIChatCompletions(hermesllm::apis::OpenAIApi::ChatCompletions), )) { Ok(ProviderRequestType::ChatCompletionsRequest(req)) => req, Ok( @@ -69,7 +67,10 @@ pub async fn router_chat_get_upstream_model( )); } Err(err) => { - warn!("Failed to convert request to ChatCompletionsRequest: {}", err); + warn!( + "Failed to convert request to ChatCompletionsRequest: {}", + err + ); return Err(RoutingError::internal_error(format!( "Failed to convert request: {}", err @@ -151,9 +152,7 @@ pub async fn router_chat_get_upstream_model( ) .await; - Ok(RoutingResult { - model_name - }) + Ok(RoutingResult { model_name }) } None => { // No route determined, use default model from request @@ -176,7 +175,7 @@ pub async fn router_chat_get_upstream_model( .await; Ok(RoutingResult { - model_name: default_model + model_name: default_model, }) } }, @@ -194,9 +193,10 @@ pub async fn router_chat_get_upstream_model( ) .await; - Err(RoutingError::internal_error( - format!("Failed to determine route: {}", err) - )) + Err(RoutingError::internal_error(format!( + "Failed to determine route: {}", + err + ))) } } } @@ -230,7 +230,10 @@ async fn record_routing_span( .with_end_time(std::time::SystemTime::now()) .with_attribute(http::METHOD, "POST") .with_attribute(http::TARGET, routing_api_path.to_string()) - .with_attribute(routing::ROUTE_DETERMINATION_MS, start_time.elapsed().as_millis().to_string()); + .with_attribute( + routing::ROUTE_DETERMINATION_MS, + start_time.elapsed().as_millis().to_string(), + ); // Only set parent span ID if it exists (not a root span) if let Some(parent) = parent_span_id { diff --git a/crates/brightstaff/src/handlers/utils.rs b/crates/brightstaff/src/handlers/utils.rs index 6f84c1f3..1529ef3e 100644 --- a/crates/brightstaff/src/handlers/utils.rs +++ b/crates/brightstaff/src/handlers/utils.rs @@ -1,5 +1,5 @@ use bytes::Bytes; -use common::traces::{Span, Attribute, AttributeValue, TraceCollector, Event}; +use common::traces::{Attribute, AttributeValue, Event, Span, TraceCollector}; use http_body_util::combinators::BoxBody; use http_body_util::StreamBody; use hyper::body::Frame; @@ -11,7 +11,7 @@ use tokio_stream::StreamExt; use tracing::warn; // Import tracing constants -use crate::tracing::{llm, error}; +use crate::tracing::{error, llm}; /// Trait for processing streaming chunks /// Implementors can inject custom logic during streaming (e.g., hallucination detection, logging) @@ -97,7 +97,6 @@ impl StreamProcessor for ObservableStreamProcessor { }, }); - self.span.attributes.push(Attribute { key: llm::DURATION_MS.to_string(), value: AttributeValue { @@ -119,11 +118,9 @@ impl StreamProcessor for ObservableStreamProcessor { if let Ok(start_time_nanos) = self.span.start_time_unix_nano.parse::() { // Convert ttft from milliseconds to nanoseconds and add to start time let event_timestamp = start_time_nanos + (ttft * 1_000_000); - let mut event = Event::new(llm::TIME_TO_FIRST_TOKEN_MS.to_string(), event_timestamp); - event.add_attribute( - llm::TIME_TO_FIRST_TOKEN_MS.to_string(), - ttft.to_string(), - ); + let mut event = + Event::new(llm::TIME_TO_FIRST_TOKEN_MS.to_string(), event_timestamp); + event.add_attribute(llm::TIME_TO_FIRST_TOKEN_MS.to_string(), ttft.to_string()); // Initialize events vector if needed if self.span.events.is_none() { @@ -137,7 +134,8 @@ impl StreamProcessor for ObservableStreamProcessor { } // Record the finalized span - self.collector.record_span(&self.service_name, self.span.clone()); + self.collector + .record_span(&self.service_name, self.span.clone()); } fn on_error(&mut self, error_msg: &str) { @@ -173,7 +171,8 @@ impl StreamProcessor for ObservableStreamProcessor { }); // Record the error span - self.collector.record_span(&self.service_name, self.span.clone()); + self.collector + .record_span(&self.service_name, self.span.clone()); } } diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index 8681b690..ca3d1771 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -4,13 +4,15 @@ use brightstaff::handlers::llm::llm_chat; use brightstaff::handlers::models::list_models; use brightstaff::router::llm_router::RouterService; use brightstaff::router::plano_orchestrator::OrchestratorService; -use brightstaff::state::StateStorage; -use brightstaff::state::postgresql::PostgreSQLConversationStorage; use brightstaff::state::memory::MemoryConversationalStorage; +use brightstaff::state::postgresql::PostgreSQLConversationStorage; +use brightstaff::state::StateStorage; use brightstaff::utils::tracing::init_tracer; use bytes::Bytes; use common::configuration::{Agent, Configuration}; -use common::consts::{CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH, PLANO_ORCHESTRATOR_MODEL_NAME}; +use common::consts::{ + CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH, PLANO_ORCHESTRATOR_MODEL_NAME, +}; use common::traces::TraceCollector; use http_body_util::{combinators::BoxBody, BodyExt, Empty}; use hyper::body::Incoming; @@ -105,7 +107,6 @@ async fn main() -> Result<(), Box> { PLANO_ORCHESTRATOR_MODEL_NAME.to_string(), )); - let model_aliases = Arc::new(arch_config.model_aliases.clone()); // Initialize trace collector and start background flusher @@ -127,33 +128,33 @@ async fn main() -> Result<(), Box> { // Configurable via arch_config.yaml state_storage section // If not configured, state management is disabled // Environment variables are substituted by envsubst before config is read - let state_storage: Option> = if let Some(storage_config) = &arch_config.state_storage { - let storage: Arc = match storage_config.storage_type { - common::configuration::StateStorageType::Memory => { - info!("Initialized conversation state storage: Memory"); - Arc::new(MemoryConversationalStorage::new()) - } - common::configuration::StateStorageType::Postgres => { - let connection_string = storage_config - .connection_string - .as_ref() - .expect("connection_string is required for postgres state_storage"); + let state_storage: Option> = + if let Some(storage_config) = &arch_config.state_storage { + let storage: Arc = match storage_config.storage_type { + common::configuration::StateStorageType::Memory => { + info!("Initialized conversation state storage: Memory"); + Arc::new(MemoryConversationalStorage::new()) + } + common::configuration::StateStorageType::Postgres => { + let connection_string = storage_config + .connection_string + .as_ref() + .expect("connection_string is required for postgres state_storage"); - debug!("Postgres connection string (full): {}", connection_string); - info!("Initializing conversation state storage: Postgres"); - Arc::new( - PostgreSQLConversationStorage::new(connection_string.clone()) - .await - .expect("Failed to initialize Postgres state storage"), - ) - } + debug!("Postgres connection string (full): {}", connection_string); + info!("Initializing conversation state storage: Postgres"); + Arc::new( + PostgreSQLConversationStorage::new(connection_string.clone()) + .await + .expect("Failed to initialize Postgres state storage"), + ) + } + }; + Some(storage) + } else { + info!("No state_storage configured - conversation state management disabled"); + None }; - Some(storage) - } else { - info!("No state_storage configured - conversation state management disabled"); - None - }; - loop { let (stream, _) = listener.accept().await?; @@ -208,12 +209,22 @@ async fn main() -> Result<(), Box> { } } match (req.method(), path) { - (&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH) => { - let fully_qualified_url = - format!("{}{}", llm_provider_url, path); - llm_chat(req, router_service, fully_qualified_url, model_aliases, llm_providers, trace_collector, state_storage) - .with_context(parent_cx) - .await + ( + &Method::POST, + CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH, + ) => { + let fully_qualified_url = format!("{}{}", llm_provider_url, path); + llm_chat( + req, + router_service, + fully_qualified_url, + model_aliases, + llm_providers, + trace_collector, + state_storage, + ) + .with_context(parent_cx) + .await } (&Method::POST, "/function_calling") => { let fully_qualified_url = diff --git a/crates/brightstaff/src/router/orchestrator_model_v1.rs b/crates/brightstaff/src/router/orchestrator_model_v1.rs index 352361ba..5e308ecf 100644 --- a/crates/brightstaff/src/router/orchestrator_model_v1.rs +++ b/crates/brightstaff/src/router/orchestrator_model_v1.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use common::configuration::{AgentUsagePreference, OrchestrationPreference}; use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role}; -use serde::{Deserialize, Serialize, ser::Serialize as SerializeTrait}; +use serde::{ser::Serialize as SerializeTrait, Deserialize, Serialize}; use tracing::{debug, warn}; use super::orchestrator_model::{OrchestratorModel, OrchestratorModelError}; @@ -144,7 +144,7 @@ impl OrchestratorModelV1 { // Format routes: each route as JSON on its own line with standard spacing let agent_orchestration_json_str = agent_orchestration_values .iter() - .map(|pref| to_spaced_json(pref)) + .map(to_spaced_json) .collect::>() .join("\n"); let agent_orchestration_to_model_map: HashMap = agent_orchestrations @@ -238,24 +238,26 @@ impl OrchestratorModel for OrchestratorModelV1 { let selected_conversation_list = selected_messages_list_reversed .iter() .rev() - .map(|message| { - Message { - role: message.role.clone(), - content: MessageContent::Text(message.content.to_string()), - name: None, - tool_calls: None, - tool_call_id: None, - } + .map(|message| Message { + role: message.role.clone(), + content: MessageContent::Text(message.content.to_string()), + name: None, + tool_calls: None, + tool_call_id: None, }) .collect::>(); // Generate the orchestrator request message based on the usage preferences. // If preferences are passed in request then we use them; // Otherwise, we use the default orchestration modelpreferences. - let orchestrator_message = match convert_to_orchestrator_preferences(usage_preferences_from_request) { - Some(prefs) => generate_orchestrator_message(&prefs, &selected_conversation_list), - None => generate_orchestrator_message(&self.agent_orchestration_json_str, &selected_conversation_list), - }; + let orchestrator_message = + match convert_to_orchestrator_preferences(usage_preferences_from_request) { + Some(prefs) => generate_orchestrator_message(&prefs, &selected_conversation_list), + None => generate_orchestrator_message( + &self.agent_orchestration_json_str, + &selected_conversation_list, + ), + }; ChatCompletionsRequest { model: self.orchestration_model.clone(), @@ -280,7 +282,8 @@ impl OrchestratorModel for OrchestratorModelV1 { return Ok(None); } let orchestrator_resp_fixed = fix_json_response(content); - let orchestrator_response: AgentOrchestratorResponse = serde_json::from_str(orchestrator_resp_fixed.as_str())?; + let orchestrator_response: AgentOrchestratorResponse = + serde_json::from_str(orchestrator_resp_fixed.as_str())?; let selected_routes = orchestrator_response.route.unwrap_or_default(); @@ -320,7 +323,11 @@ impl OrchestratorModel for OrchestratorModelV1 { } else { // If no usage preferences are passed in request then use the default orchestration model preferences for selected_route in valid_routes { - if let Some(model) = self.agent_orchestration_to_model_map.get(&selected_route).cloned() { + if let Some(model) = self + .agent_orchestration_to_model_map + .get(&selected_route) + .cloned() + { result.push((selected_route, model)); } else { warn!( @@ -375,7 +382,7 @@ fn convert_to_orchestrator_preferences( // Format routes: each route as JSON on its own line with standard spacing let routes_str = orchestration_preferences .iter() - .map(|pref| to_spaced_json(pref)) + .map(to_spaced_json) .collect::>() .join("\n"); @@ -425,7 +432,10 @@ mod tests { // CRITICAL: Test that colons inside string values are NOT modified let with_colon = serde_json::json!({"name": "foo:bar", "url": "http://example.com"}); let result = to_spaced_json(&with_colon); - assert_eq!(result, r#"{"name": "foo:bar", "url": "http://example.com"}"#); + assert_eq!( + result, + r#"{"name": "foo:bar", "url": "http://example.com"}"# + ); // Test empty object and array let empty_obj = serde_json::json!({}); @@ -446,7 +456,8 @@ mod tests { }); let result = to_spaced_json(&complex); // Verify URLs with colons are preserved correctly - assert!(result.contains(r#""urls": ["https://api.example.com:8080/path", "file:///local/path"]"#)); + assert!(result + .contains(r#""urls": ["https://api.example.com:8080/path", "file:///local/path"]"#)); // Verify spacing format assert!(result.contains(r#""type": "object""#)); assert!(result.contains(r#""properties": {}"#)); @@ -497,10 +508,16 @@ If no routes are needed, return an empty list for `route`. ] } "#; - let agent_orchestrations = - serde_json::from_str::>>(orchestrations_str).unwrap(); + let agent_orchestrations = serde_json::from_str::< + HashMap>, + >(orchestrations_str) + .unwrap(); let orchestration_model = "test-model".to_string(); - let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), usize::MAX); + let orchestrator = OrchestratorModelV1::new( + agent_orchestrations, + orchestration_model.clone(), + usize::MAX, + ); let conversation_str = r#" [ @@ -568,7 +585,11 @@ If no routes are needed, return an empty list for `route`. // Empty orchestrations map - not used when usage_preferences are provided let agent_orchestrations: HashMap> = HashMap::new(); let orchestration_model = "test-model".to_string(); - let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), usize::MAX); + let orchestrator = OrchestratorModelV1::new( + agent_orchestrations, + orchestration_model.clone(), + usize::MAX, + ); let conversation_str = r#" [ @@ -640,10 +661,13 @@ If no routes are needed, return an empty list for `route`. ] } "#; - let agent_orchestrations = - serde_json::from_str::>>(orchestrations_str).unwrap(); + let agent_orchestrations = serde_json::from_str::< + HashMap>, + >(orchestrations_str) + .unwrap(); let orchestration_model = "test-model".to_string(); - let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), 235); + let orchestrator = + OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), 235); let conversation_str = r#" [ @@ -709,11 +733,14 @@ If no routes are needed, return an empty list for `route`. ] } "#; - let agent_orchestrations = - serde_json::from_str::>>(orchestrations_str).unwrap(); + let agent_orchestrations = serde_json::from_str::< + HashMap>, + >(orchestrations_str) + .unwrap(); let orchestration_model = "test-model".to_string(); - let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), 200); + let orchestrator = + OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), 200); let conversation_str = r#" [ @@ -787,10 +814,13 @@ If no routes are needed, return an empty list for `route`. ] } "#; - let agent_orchestrations = - serde_json::from_str::>>(orchestrations_str).unwrap(); + let agent_orchestrations = serde_json::from_str::< + HashMap>, + >(orchestrations_str) + .unwrap(); let orchestration_model = "test-model".to_string(); - let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), 230); + let orchestrator = + OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), 230); let conversation_str = r#" [ @@ -871,10 +901,16 @@ If no routes are needed, return an empty list for `route`. ] } "#; - let agent_orchestrations = - serde_json::from_str::>>(orchestrations_str).unwrap(); + let agent_orchestrations = serde_json::from_str::< + HashMap>, + >(orchestrations_str) + .unwrap(); let orchestration_model = "test-model".to_string(); - let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), usize::MAX); + let orchestrator = OrchestratorModelV1::new( + agent_orchestrations, + orchestration_model.clone(), + usize::MAX, + ); let conversation_str = r#" [ @@ -957,10 +993,16 @@ If no routes are needed, return an empty list for `route`. ] } "#; - let agent_orchestrations = - serde_json::from_str::>>(orchestrations_str).unwrap(); + let agent_orchestrations = serde_json::from_str::< + HashMap>, + >(orchestrations_str) + .unwrap(); let orchestration_model = "test-model".to_string(); - let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), usize::MAX); + let orchestrator = OrchestratorModelV1::new( + agent_orchestrations, + orchestration_model.clone(), + usize::MAX, + ); let conversation_str = r#" [ @@ -1034,10 +1076,13 @@ If no routes are needed, return an empty list for `route`. ] } "#; - let agent_orchestrations = - serde_json::from_str::>>(orchestrations_str).unwrap(); + let agent_orchestrations = serde_json::from_str::< + HashMap>, + >(orchestrations_str) + .unwrap(); - let orchestrator = OrchestratorModelV1::new(agent_orchestrations, "test-model".to_string(), 2000); + let orchestrator = + OrchestratorModelV1::new(agent_orchestrations, "test-model".to_string(), 2000); // Case 1: Valid JSON with single route in array let input = r#"{"route": ["Image generation"]}"#; diff --git a/crates/brightstaff/src/router/plano_orchestrator.rs b/crates/brightstaff/src/router/plano_orchestrator.rs index 62efedce..5aff4b11 100644 --- a/crates/brightstaff/src/router/plano_orchestrator.rs +++ b/crates/brightstaff/src/router/plano_orchestrator.rs @@ -34,10 +34,7 @@ pub enum OrchestrationError { pub type Result = std::result::Result; impl OrchestratorService { - pub fn new( - orchestrator_url: String, - orchestration_model_name: String, - ) -> Self { + pub fn new(orchestrator_url: String, orchestration_model_name: String) -> Self { // Empty agent orchestrations - will be provided via usage_preferences in requests let agent_orchestrations: HashMap> = HashMap::new(); diff --git a/crates/brightstaff/src/state/memory.rs b/crates/brightstaff/src/state/memory.rs index d805d655..be4d8232 100644 --- a/crates/brightstaff/src/state/memory.rs +++ b/crates/brightstaff/src/state/memory.rs @@ -85,13 +85,19 @@ impl StateStorage for MemoryConversationalStorage { #[cfg(test)] mod tests { use super::*; - use hermesllm::apis::openai_responses::{InputItem, InputMessage, MessageRole, InputContent, MessageContent}; + use hermesllm::apis::openai_responses::{ + InputContent, InputItem, InputMessage, MessageContent, MessageRole, + }; fn create_test_state(response_id: &str, num_messages: usize) -> OpenAIConversationState { let mut input_items = Vec::new(); for i in 0..num_messages { input_items.push(InputItem::Message(InputMessage { - role: if i % 2 == 0 { MessageRole::User } else { MessageRole::Assistant }, + role: if i % 2 == 0 { + MessageRole::User + } else { + MessageRole::Assistant + }, content: MessageContent::Items(vec![InputContent::InputText { text: format!("Message {}", i), }]), @@ -252,7 +258,9 @@ mod tests { let merged = storage.merge(&prev_state, current_input); // Verify order: prev messages first, then current - let InputItem::Message(msg) = &merged[0] else { panic!("Expected Message") }; + let InputItem::Message(msg) = &merged[0] else { + panic!("Expected Message") + }; match &msg.content { MessageContent::Items(items) => match &items[0] { InputContent::InputText { text } => assert_eq!(text, "Message 0"), @@ -261,7 +269,9 @@ mod tests { _ => panic!("Expected MessageContent::Items"), } - let InputItem::Message(msg) = &merged[2] else { panic!("Expected Message") }; + let InputItem::Message(msg) = &merged[2] else { + panic!("Expected Message") + }; match &msg.content { MessageContent::Items(items) => match &items[0] { InputContent::InputText { text } => assert_eq!(text, "Message 2"), @@ -404,7 +414,8 @@ mod tests { let current_input = vec![InputItem::Message(InputMessage { role: MessageRole::User, content: MessageContent::Items(vec![InputContent::InputText { - text: "Function result: {\"temperature\": 72, \"condition\": \"sunny\"}".to_string(), + text: "Function result: {\"temperature\": 72, \"condition\": \"sunny\"}" + .to_string(), }]), })]; @@ -415,7 +426,9 @@ mod tests { assert_eq!(merged.len(), 3); // Verify the order and content - let InputItem::Message(msg1) = &merged[0] else { panic!("Expected Message") }; + let InputItem::Message(msg1) = &merged[0] else { + panic!("Expected Message") + }; assert!(matches!(msg1.role, MessageRole::User)); match &msg1.content { MessageContent::Items(items) => match &items[0] { @@ -427,7 +440,9 @@ mod tests { _ => panic!("Expected MessageContent::Items"), } - let InputItem::Message(msg2) = &merged[1] else { panic!("Expected Message") }; + let InputItem::Message(msg2) = &merged[1] else { + panic!("Expected Message") + }; assert!(matches!(msg2.role, MessageRole::Assistant)); match &msg2.content { MessageContent::Items(items) => match &items[0] { @@ -439,7 +454,9 @@ mod tests { _ => panic!("Expected MessageContent::Items"), } - let InputItem::Message(msg3) = &merged[2] else { panic!("Expected Message") }; + let InputItem::Message(msg3) = &merged[2] else { + panic!("Expected Message") + }; assert!(matches!(msg3.role, MessageRole::User)); match &msg3.content { MessageContent::Items(items) => match &items[0] { @@ -508,11 +525,15 @@ mod tests { assert_eq!(merged.len(), 5); // Verify first item is original user message - let InputItem::Message(first) = &merged[0] else { panic!("Expected Message") }; + let InputItem::Message(first) = &merged[0] else { + panic!("Expected Message") + }; assert!(matches!(first.role, MessageRole::User)); // Verify last two are function outputs - let InputItem::Message(second_last) = &merged[3] else { panic!("Expected Message") }; + let InputItem::Message(second_last) = &merged[3] else { + panic!("Expected Message") + }; assert!(matches!(second_last.role, MessageRole::User)); match &second_last.content { MessageContent::Items(items) => match &items[0] { @@ -522,7 +543,9 @@ mod tests { _ => panic!("Expected MessageContent::Items"), } - let InputItem::Message(last) = &merged[4] else { panic!("Expected Message") }; + let InputItem::Message(last) = &merged[4] else { + panic!("Expected Message") + }; assert!(matches!(last.role, MessageRole::User)); match &last.content { MessageContent::Items(items) => match &items[0] { @@ -590,7 +613,9 @@ mod tests { assert_eq!(merged.len(), 5); // Verify the entire conversation flow is preserved - let InputItem::Message(first) = &merged[0] else { panic!("Expected Message") }; + let InputItem::Message(first) = &merged[0] else { + panic!("Expected Message") + }; match &first.content { MessageContent::Items(items) => match &items[0] { InputContent::InputText { text } => assert!(text.contains("What's the weather")), @@ -599,7 +624,9 @@ mod tests { _ => panic!("Expected MessageContent::Items"), } - let InputItem::Message(last) = &merged[4] else { panic!("Expected Message") }; + let InputItem::Message(last) = &merged[4] else { + panic!("Expected Message") + }; match &last.content { MessageContent::Items(items) => match &items[0] { InputContent::InputText { text } => assert!(text.contains("umbrella")), diff --git a/crates/brightstaff/src/state/mod.rs b/crates/brightstaff/src/state/mod.rs index f2b96da0..ce3ec8ae 100644 --- a/crates/brightstaff/src/state/mod.rs +++ b/crates/brightstaff/src/state/mod.rs @@ -1,14 +1,16 @@ use async_trait::async_trait; -use hermesllm::apis::openai_responses::{InputItem, InputMessage, InputContent, MessageContent, MessageRole, InputParam}; +use hermesllm::apis::openai_responses::{ + InputContent, InputItem, InputMessage, InputParam, MessageContent, MessageRole, +}; use serde::{Deserialize, Serialize}; use std::error::Error; use std::fmt; use std::sync::Arc; -use tracing::{debug}; +use tracing::debug; pub mod memory; -pub mod response_state_processor; pub mod postgresql; +pub mod response_state_processor; /// Represents the conversational state for a v1/responses request /// Contains the complete input/output history that can be restored @@ -47,7 +49,9 @@ pub enum StateStorageError { impl fmt::Display for StateStorageError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - StateStorageError::NotFound(id) => write!(f, "Conversation state not found for response_id: {}", id), + StateStorageError::NotFound(id) => { + write!(f, "Conversation state not found for response_id: {}", id) + } StateStorageError::StorageError(msg) => write!(f, "Storage error: {}", msg), StateStorageError::SerializationError(msg) => write!(f, "Serialization error: {}", msg), } @@ -96,8 +100,6 @@ pub trait StateStorage: Send + Sync { } } - - /// Storage backend type enum #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum StorageBackend { @@ -106,7 +108,7 @@ pub enum StorageBackend { } impl StorageBackend { - pub fn from_str(s: &str) -> Option { + pub fn parse_backend(s: &str) -> Option { match s.to_lowercase().as_str() { "memory" => Some(StorageBackend::Memory), "supabase" => Some(StorageBackend::Supabase), @@ -139,7 +141,6 @@ pub async fn retrieve_and_combine_input( previous_response_id: &str, current_input: Vec, ) -> Result, StateStorageError> { - // First get the previous state let prev_state = storage.get(previous_response_id).await?; let combined_input = storage.merge(&prev_state, current_input); diff --git a/crates/brightstaff/src/state/postgresql.rs b/crates/brightstaff/src/state/postgresql.rs index 529f27e9..fe27580e 100644 --- a/crates/brightstaff/src/state/postgresql.rs +++ b/crates/brightstaff/src/state/postgresql.rs @@ -149,13 +149,12 @@ impl StateStorage for PostgreSQLConversationStorage { let provider: String = row.get("provider"); // Deserialize input_items from JSONB - let input_items = - serde_json::from_value(input_items_json).map_err(|e| { - StateStorageError::StorageError(format!( - "Failed to deserialize input_items: {}", - e - )) - })?; + let input_items = serde_json::from_value(input_items_json).map_err(|e| { + StateStorageError::StorageError(format!( + "Failed to deserialize input_items: {}", + e + )) + })?; Ok(OpenAIConversationState { response_id, @@ -230,7 +229,9 @@ Run that SQL file against your database before using this storage backend. #[cfg(test)] mod tests { use super::*; - use hermesllm::apis::openai_responses::{InputContent, InputItem, InputMessage, MessageContent, MessageRole}; + use hermesllm::apis::openai_responses::{ + InputContent, InputItem, InputMessage, MessageContent, MessageRole, + }; fn create_test_state(response_id: &str) -> OpenAIConversationState { OpenAIConversationState { @@ -320,7 +321,10 @@ mod tests { let result = storage.get("nonexistent_id").await; assert!(result.is_err()); - assert!(matches!(result.unwrap_err(), StateStorageError::NotFound(_))); + assert!(matches!( + result.unwrap_err(), + StateStorageError::NotFound(_) + )); } #[tokio::test] @@ -372,7 +376,10 @@ mod tests { let result = storage.delete("nonexistent_id").await; assert!(result.is_err()); - assert!(matches!(result.unwrap_err(), StateStorageError::NotFound(_))); + assert!(matches!( + result.unwrap_err(), + StateStorageError::NotFound(_) + )); } #[tokio::test] @@ -423,9 +430,13 @@ mod tests { println!("✅ Data written to Supabase!"); println!("Check your Supabase dashboard:"); - println!(" SELECT * FROM conversation_states WHERE response_id = 'manual_test_verification';"); + println!( + " SELECT * FROM conversation_states WHERE response_id = 'manual_test_verification';" + ); println!("\nTo cleanup, run:"); - println!(" DELETE FROM conversation_states WHERE response_id = 'manual_test_verification';"); + println!( + " DELETE FROM conversation_states WHERE response_id = 'manual_test_verification';" + ); // DON'T cleanup - leave it for manual verification } diff --git a/crates/brightstaff/src/state/response_state_processor.rs b/crates/brightstaff/src/state/response_state_processor.rs index b3ce6787..3d1e8673 100644 --- a/crates/brightstaff/src/state/response_state_processor.rs +++ b/crates/brightstaff/src/state/response_state_processor.rs @@ -1,13 +1,11 @@ use bytes::Bytes; use flate2::read::GzDecoder; -use hermesllm::apis::openai_responses::{ - InputItem, OutputItem, ResponsesAPIStreamEvent, -}; +use hermesllm::apis::openai_responses::{InputItem, OutputItem, ResponsesAPIStreamEvent}; use hermesllm::apis::streaming_shapes::sse::SseStreamIter; use hermesllm::transforms::response::output_to_input::outputs_to_inputs; use std::io::Read; use std::sync::Arc; -use tracing::{info, debug, warn}; +use tracing::{debug, info, warn}; use crate::handlers::utils::StreamProcessor; use crate::state::{OpenAIConversationState, StateStorage}; @@ -53,6 +51,7 @@ pub struct ResponsesStateProcessor { } impl ResponsesStateProcessor

{ + #[allow(clippy::too_many_arguments)] pub fn new( inner: P, storage: Arc, @@ -139,20 +138,19 @@ impl ResponsesStateProcessor

{ for event in sse_iter { // Only process data lines (skip event-only lines) if let Some(data_str) = &event.data { - // Try to parse as ResponsesAPIStreamEvent - if let Ok(stream_event) = serde_json::from_str::(data_str) { - // Check if this is a ResponseCompleted event - if let ResponsesAPIStreamEvent::ResponseCompleted { response, .. } = stream_event { - info!( - "[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Captured streaming response.completed: response_id={}, output_items={}", - self.request_id, - response.id, - response.output.len() - ); - self.response_id = Some(response.id.clone()); - self.output_items = Some(response.output.clone()); - return; // Found what we need, exit early - } + // Try to parse as ResponsesAPIStreamEvent and check if it's a ResponseCompleted event + if let Ok(ResponsesAPIStreamEvent::ResponseCompleted { response, .. }) = + serde_json::from_str::(data_str) + { + info!( + "[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Captured streaming response.completed: response_id={}, output_items={}", + self.request_id, + response.id, + response.output.len() + ); + self.response_id = Some(response.id.clone()); + self.output_items = Some(response.output.clone()); + return; // Found what we need, exit early } } } @@ -172,7 +170,9 @@ impl ResponsesStateProcessor

{ let decompressed = self.decompress_buffer(); // Parse complete JSON response - match serde_json::from_slice::(&decompressed) { + match serde_json::from_slice::( + &decompressed, + ) { Ok(response) => { info!( "[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Captured non-streaming response: response_id={}, output_items={}", diff --git a/crates/brightstaff/src/tracing/constants.rs b/crates/brightstaff/src/tracing/constants.rs index 1edc4d3d..aac48802 100644 --- a/crates/brightstaff/src/tracing/constants.rs +++ b/crates/brightstaff/src/tracing/constants.rs @@ -2,11 +2,9 @@ /// /// This module defines standard attribute keys following OTEL semantic conventions. /// See: https://opentelemetry.io/docs/specs/semconv/ - // ============================================================================= // Span Attributes - HTTP // ============================================================================= - /// Semantic conventions for HTTP-related span attributes pub mod http { /// HTTP request method diff --git a/crates/brightstaff/src/tracing/mod.rs b/crates/brightstaff/src/tracing/mod.rs index bacc9571..4c7f099f 100644 --- a/crates/brightstaff/src/tracing/mod.rs +++ b/crates/brightstaff/src/tracing/mod.rs @@ -1,3 +1,3 @@ mod constants; -pub use constants::{OperationNameBuilder, operation_component, http, llm, error, routing}; +pub use constants::{error, http, llm, operation_component, routing, OperationNameBuilder}; diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index 005e1264..07d3311b 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -295,11 +295,14 @@ impl serde::Serialize for OrchestrationPreference { let mut state = serializer.serialize_struct("OrchestrationPreference", 3)?; state.serialize_field("name", &self.name)?; state.serialize_field("description", &self.description)?; - state.serialize_field("parameters", &serde_json::json!({ - "type": "object", - "properties": {}, - "required": [] - }))?; + state.serialize_field( + "parameters", + &serde_json::json!({ + "type": "object", + "properties": {}, + "required": [] + }), + )?; state.end() } } @@ -489,7 +492,10 @@ mod test { assert_eq!(config.version, "v0.3.0"); if let Some(prompt_targets) = &config.prompt_targets { - assert!(!prompt_targets.is_empty(), "prompt_targets should not be empty if present"); + assert!( + !prompt_targets.is_empty(), + "prompt_targets should not be empty if present" + ); } if let Some(tracing) = config.tracing.as_ref() { @@ -510,19 +516,48 @@ mod test { .expect("reference config file not found"); let config: super::Configuration = serde_yaml::from_str(&ref_config).unwrap(); if let Some(prompt_targets) = &config.prompt_targets { - if let Some(prompt_target) = prompt_targets.iter().find(|p| p.name == "reboot_network_device") { + if let Some(prompt_target) = prompt_targets + .iter() + .find(|p| p.name == "reboot_network_device") + { let chat_completion_tool: super::ChatCompletionTool = prompt_target.into(); assert_eq!(chat_completion_tool.tool_type, ToolType::Function); assert_eq!(chat_completion_tool.function.name, "reboot_network_device"); - assert_eq!(chat_completion_tool.function.description, "Reboot a specific network device"); + assert_eq!( + chat_completion_tool.function.description, + "Reboot a specific network device" + ); assert_eq!(chat_completion_tool.function.parameters.properties.len(), 2); - assert!(chat_completion_tool.function.parameters.properties.contains_key("device_id")); - let device_id_param = chat_completion_tool.function.parameters.properties.get("device_id").unwrap(); - assert_eq!(device_id_param.parameter_type, crate::api::open_ai::ParameterType::String); - assert_eq!(device_id_param.description, "Identifier of the network device to reboot.".to_string()); + assert!(chat_completion_tool + .function + .parameters + .properties + .contains_key("device_id")); + let device_id_param = chat_completion_tool + .function + .parameters + .properties + .get("device_id") + .unwrap(); + assert_eq!( + device_id_param.parameter_type, + crate::api::open_ai::ParameterType::String + ); + assert_eq!( + device_id_param.description, + "Identifier of the network device to reboot.".to_string() + ); assert_eq!(device_id_param.required, Some(true)); - let confirmation_param = chat_completion_tool.function.parameters.properties.get("confirmation").unwrap(); - assert_eq!(confirmation_param.parameter_type, crate::api::open_ai::ParameterType::Bool); + let confirmation_param = chat_completion_tool + .function + .parameters + .properties + .get("confirmation") + .unwrap(); + assert_eq!( + confirmation_param.parameter_type, + crate::api::open_ai::ParameterType::Bool + ); } } } diff --git a/crates/common/src/consts.rs b/crates/common/src/consts.rs index 1ef254d8..e1b5dc1a 100644 --- a/crates/common/src/consts.rs +++ b/crates/common/src/consts.rs @@ -32,6 +32,6 @@ pub const OTEL_COLLECTOR_HTTP: &str = "opentelemetry_collector_http"; pub const OTEL_POST_PATH: &str = "/v1/traces"; pub const LLM_ROUTE_HEADER: &str = "x-arch-llm-route"; pub const ENVOY_RETRY_HEADER: &str = "x-envoy-max-retries"; -pub const BRIGHT_STAFF_SERVICE_NAME : &str = "brightstaff"; +pub const BRIGHT_STAFF_SERVICE_NAME: &str = "brightstaff"; pub const PLANO_ORCHESTRATOR_MODEL_NAME: &str = "Plano-Orchestrator"; pub const ARCH_FC_CLUSTER: &str = "arch"; diff --git a/crates/common/src/lib.rs b/crates/common/src/lib.rs index 9c8f5787..aba27b9b 100644 --- a/crates/common/src/lib.rs +++ b/crates/common/src/lib.rs @@ -10,6 +10,6 @@ pub mod ratelimit; pub mod routing; pub mod stats; pub mod tokenizer; -pub mod tracing; pub mod traces; +pub mod tracing; pub mod utils; diff --git a/crates/common/src/routing.rs b/crates/common/src/routing.rs index 8813c92d..918435a6 100644 --- a/crates/common/src/routing.rs +++ b/crates/common/src/routing.rs @@ -41,7 +41,8 @@ pub fn get_llm_provider( llm_providers .iter() .filter(|(_, provider)| { - provider.model + provider + .model .as_ref() .map(|m| !m.starts_with("Arch")) .unwrap_or(true) diff --git a/crates/common/src/traces/collector.rs b/crates/common/src/traces/collector.rs index 0fc407b2..e26f544e 100644 --- a/crates/common/src/traces/collector.rs +++ b/crates/common/src/traces/collector.rs @@ -1,5 +1,5 @@ -use super::shapes::Span; use super::resource_span_builder::ResourceSpanBuilder; +use super::shapes::Span; use std::collections::{HashMap, VecDeque}; use std::sync::Arc; use tokio::sync::Mutex; @@ -160,7 +160,11 @@ impl TraceCollector { } let total_spans: usize = service_batches.iter().map(|(_, spans)| spans.len()).sum(); - debug!("Flushing {} spans across {} services to OTEL collector", total_spans, service_batches.len()); + debug!( + "Flushing {} spans across {} services to OTEL collector", + total_spans, + service_batches.len() + ); // Build canonical OTEL payload structure - one ResourceSpan per service let resource_spans = self.build_resource_spans(service_batches); @@ -178,7 +182,10 @@ impl TraceCollector { } /// Build OTEL-compliant resource spans from collected spans, one ResourceSpan per service - fn build_resource_spans(&self, service_batches: Vec<(String, Vec)>) -> Vec { + fn build_resource_spans( + &self, + service_batches: Vec<(String, Vec)>, + ) -> Vec { service_batches .into_iter() .map(|(service_name, spans)| { diff --git a/crates/common/src/traces/constants.rs b/crates/common/src/traces/constants.rs index 09bdecd5..9e637c57 100644 --- a/crates/common/src/traces/constants.rs +++ b/crates/common/src/traces/constants.rs @@ -1,7 +1,6 @@ /// OpenTelemetry semantic convention constants for tracing /// /// These constants ensure consistency across the codebase and prevent typos - /// Resource attribute keys following OTEL semantic conventions pub mod resource { /// Logical name of the service diff --git a/crates/common/src/traces/mod.rs b/crates/common/src/traces/mod.rs index c4197995..6181f194 100644 --- a/crates/common/src/traces/mod.rs +++ b/crates/common/src/traces/mod.rs @@ -1,9 +1,9 @@ // Original tracing types (OTEL structures) mod shapes; // New tracing utilities -mod span_builder; -mod resource_span_builder; mod constants; +mod resource_span_builder; +mod span_builder; #[cfg(feature = "trace-collection")] mod collector; @@ -13,14 +13,14 @@ mod tests; // Re-export original types pub use shapes::{ - Span, Event, Traceparent, TraceparentNewError, - ResourceSpan, Resource, ScopeSpan, Scope, Attribute, AttributeValue, + Attribute, AttributeValue, Event, Resource, ResourceSpan, Scope, ScopeSpan, Span, Traceparent, + TraceparentNewError, }; // Re-export new utilities -pub use span_builder::{SpanBuilder, SpanKind, generate_random_span_id}; -pub use resource_span_builder::ResourceSpanBuilder; pub use constants::*; +pub use resource_span_builder::ResourceSpanBuilder; +pub use span_builder::{generate_random_span_id, SpanBuilder, SpanKind}; #[cfg(feature = "trace-collection")] -pub use collector::{TraceCollector, parse_traceparent}; +pub use collector::{parse_traceparent, TraceCollector}; diff --git a/crates/common/src/traces/resource_span_builder.rs b/crates/common/src/traces/resource_span_builder.rs index 3e0dd88f..42cdfdae 100644 --- a/crates/common/src/traces/resource_span_builder.rs +++ b/crates/common/src/traces/resource_span_builder.rs @@ -1,5 +1,5 @@ -use super::shapes::{ResourceSpan, Resource, ScopeSpan, Scope, Span, Attribute, AttributeValue}; use super::constants::{resource, scope}; +use super::shapes::{Attribute, AttributeValue, Resource, ResourceSpan, Scope, ScopeSpan, Span}; use std::collections::HashMap; /// Builder for creating OTEL ResourceSpan structures @@ -26,7 +26,11 @@ impl ResourceSpanBuilder { } /// Add a resource attribute (e.g., deployment.environment, host.name) - pub fn with_resource_attribute(mut self, key: impl Into, value: impl Into) -> Self { + pub fn with_resource_attribute( + mut self, + key: impl Into, + value: impl Into, + ) -> Self { self.resource_attributes.insert(key.into(), value.into()); self } @@ -58,14 +62,12 @@ impl ResourceSpanBuilder { /// Build the ResourceSpan pub fn build(self) -> ResourceSpan { // Build resource attributes - let mut attributes = vec![ - Attribute { - key: resource::SERVICE_NAME.to_string(), - value: AttributeValue { - string_value: Some(self.service_name), - }, - } - ]; + let mut attributes = vec![Attribute { + key: resource::SERVICE_NAME.to_string(), + value: AttributeValue { + string_value: Some(self.service_name), + }, + }]; // Add custom resource attributes for (key, value) in self.resource_attributes { diff --git a/crates/common/src/traces/span_builder.rs b/crates/common/src/traces/span_builder.rs index e07cfab9..fa162afa 100644 --- a/crates/common/src/traces/span_builder.rs +++ b/crates/common/src/traces/span_builder.rs @@ -1,4 +1,4 @@ -use super::shapes::{Span, Attribute, AttributeValue}; +use super::shapes::{Attribute, AttributeValue, Span}; use std::collections::HashMap; use std::time::SystemTime; @@ -116,10 +116,11 @@ impl SpanBuilder { let end_nanos = system_time_to_nanos(end_time); // Generate trace_id if not provided - let trace_id = self.trace_id.unwrap_or_else(|| generate_random_trace_id()); + let trace_id = self.trace_id.unwrap_or_else(generate_random_trace_id); // Create attributes in OTEL format - let attributes: Vec = self.attributes + let attributes: Vec = self + .attributes .into_iter() .map(|(key, value)| Attribute { key, @@ -132,7 +133,7 @@ impl SpanBuilder { // Build span directly without going through Span::new() Span { trace_id, - span_id: self.span_id.unwrap_or_else(|| generate_random_span_id()), + span_id: self.span_id.unwrap_or_else(generate_random_span_id), parent_span_id: self.parent_span_id, name: self.name, start_time_unix_nano: format!("{}", start_nanos), diff --git a/crates/common/src/traces/tests/mock_otel_collector.rs b/crates/common/src/traces/tests/mock_otel_collector.rs index 8a154145..8c8e770d 100644 --- a/crates/common/src/traces/tests/mock_otel_collector.rs +++ b/crates/common/src/traces/tests/mock_otel_collector.rs @@ -21,10 +21,7 @@ use tokio::sync::RwLock; type SharedTraces = Arc>>; /// POST /v1/traces - capture incoming OTLP payload -async fn post_traces( - State(traces): State, - Json(payload): Json, -) -> StatusCode { +async fn post_traces(State(traces): State, Json(payload): Json) -> StatusCode { traces.write().await.push(payload); StatusCode::OK } @@ -67,9 +64,7 @@ impl MockOtelCollector { let address = format!("http://127.0.0.1:{}", addr.port()); let server_handle = tokio::spawn(async move { - axum::serve(listener, app) - .await - .expect("Server failed"); + axum::serve(listener, app).await.expect("Server failed"); }); // Give server a moment to start diff --git a/crates/common/src/traces/tests/trace_integration_test.rs b/crates/common/src/traces/tests/trace_integration_test.rs index a3c8a6ba..5f41a2c3 100644 --- a/crates/common/src/traces/tests/trace_integration_test.rs +++ b/crates/common/src/traces/tests/trace_integration_test.rs @@ -36,9 +36,12 @@ fn extract_spans(payloads: &[Value]) -> Vec<&Value> { for payload in payloads { if let Some(resource_spans) = payload.get("resourceSpans").and_then(|v| v.as_array()) { for resource_span in resource_spans { - if let Some(scope_spans) = resource_span.get("scopeSpans").and_then(|v| v.as_array()) { + if let Some(scope_spans) = + resource_span.get("scopeSpans").and_then(|v| v.as_array()) + { for scope_span in scope_spans { - if let Some(span_list) = scope_span.get("spans").and_then(|v| v.as_array()) { + if let Some(span_list) = scope_span.get("spans").and_then(|v| v.as_array()) + { spans.extend(span_list.iter()); } } @@ -54,9 +57,9 @@ fn get_string_attr<'a>(span: &'a Value, key: &str) -> Option<&'a str> { span.get("attributes") .and_then(|attrs| attrs.as_array()) .and_then(|attrs| { - attrs.iter().find(|attr| { - attr.get("key").and_then(|k| k.as_str()) == Some(key) - }) + attrs + .iter() + .find(|attr| attr.get("key").and_then(|k| k.as_str()) == Some(key)) }) .and_then(|attr| attr.get("value")) .and_then(|v| v.get("stringValue")) @@ -70,7 +73,10 @@ async fn test_llm_span_contains_basic_attributes() { let mock_collector = MockOtelCollector::start().await; // Create TraceCollector pointing to mock with 500ms flush intervalc - std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address())); + std::env::set_var( + "OTEL_COLLECTOR_URL", + format!("{}/v1/traces", mock_collector.address()), + ); std::env::set_var("OTEL_TRACING_ENABLED", "true"); std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500"); let trace_collector = Arc::new(TraceCollector::new(Some(true))); @@ -102,7 +108,10 @@ async fn test_llm_span_contains_basic_attributes() { let span = spans[0]; // Validate HTTP attributes assert_eq!(get_string_attr(span, "http.method"), Some("POST")); - assert_eq!(get_string_attr(span, "http.target"), Some("/v1/chat/completions")); + assert_eq!( + get_string_attr(span, "http.target"), + Some("/v1/chat/completions") + ); // Validate LLM attributes assert_eq!(get_string_attr(span, "llm.model"), Some("gpt-4o")); @@ -115,7 +124,10 @@ async fn test_llm_span_contains_basic_attributes() { #[serial] async fn test_llm_span_contains_tool_information() { let mock_collector = MockOtelCollector::start().await; - std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address())); + std::env::set_var( + "OTEL_COLLECTOR_URL", + format!("{}/v1/traces", mock_collector.address()), + ); std::env::set_var("OTEL_TRACING_ENABLED", "true"); std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500"); let trace_collector = Arc::new(TraceCollector::new(Some(true))); @@ -144,19 +156,26 @@ async fn test_llm_span_contains_tool_information() { assert!(tools.unwrap().contains("get_weather(...)")); assert!(tools.unwrap().contains("search_web(...)")); assert!(tools.unwrap().contains("calculate(...)")); - assert!(tools.unwrap().contains('\n'), "Tools should be newline-separated"); + assert!( + tools.unwrap().contains('\n'), + "Tools should be newline-separated" + ); } #[tokio::test] #[serial] async fn test_llm_span_contains_user_message_preview() { let mock_collector = MockOtelCollector::start().await; - std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address())); + std::env::set_var( + "OTEL_COLLECTOR_URL", + format!("{}/v1/traces", mock_collector.address()), + ); std::env::set_var("OTEL_TRACING_ENABLED", "true"); std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500"); let trace_collector = Arc::new(TraceCollector::new(Some(true))); - let long_message = "This is a very long user message that should be truncated to 50 characters in the span"; + let long_message = + "This is a very long user message that should be truncated to 50 characters in the span"; let preview = if long_message.len() > 50 { format!("{}...", &long_message[..50]) } else { @@ -187,7 +206,10 @@ async fn test_llm_span_contains_user_message_preview() { #[serial] async fn test_llm_span_contains_time_to_first_token() { let mock_collector = MockOtelCollector::start().await; - std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address())); + std::env::set_var( + "OTEL_COLLECTOR_URL", + format!("{}/v1/traces", mock_collector.address()), + ); std::env::set_var("OTEL_TRACING_ENABLED", "true"); std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500"); let trace_collector = Arc::new(TraceCollector::new(Some(true))); @@ -217,7 +239,10 @@ async fn test_llm_span_contains_time_to_first_token() { #[serial] async fn test_llm_span_contains_upstream_path() { let mock_collector = MockOtelCollector::start().await; - std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address())); + std::env::set_var( + "OTEL_COLLECTOR_URL", + format!("{}/v1/traces", mock_collector.address()), + ); std::env::set_var("OTEL_TRACING_ENABLED", "true"); std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500"); let trace_collector = Arc::new(TraceCollector::new(Some(true))); @@ -241,7 +266,10 @@ async fn test_llm_span_contains_upstream_path() { // Operation name should show the transformation let name = span.get("name").and_then(|v| v.as_str()); assert!(name.is_some()); - assert!(name.unwrap().contains(">>"), "Operation name should show path transformation"); + assert!( + name.unwrap().contains(">>"), + "Operation name should show path transformation" + ); // Check upstream target attribute let upstream = get_string_attr(span, "http.upstream_target"); @@ -252,7 +280,10 @@ async fn test_llm_span_contains_upstream_path() { #[serial] async fn test_llm_span_multiple_services() { let mock_collector = MockOtelCollector::start().await; - std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address())); + std::env::set_var( + "OTEL_COLLECTOR_URL", + format!("{}/v1/traces", mock_collector.address()), + ); std::env::set_var("OTEL_TRACING_ENABLED", "true"); std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500"); let trace_collector = Arc::new(TraceCollector::new(Some(true))); @@ -285,7 +316,10 @@ async fn test_tracing_disabled_produces_no_spans() { let mock_collector = MockOtelCollector::start().await; // Create TraceCollector with tracing DISABLED - std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address())); + std::env::set_var( + "OTEL_COLLECTOR_URL", + format!("{}/v1/traces", mock_collector.address()), + ); std::env::set_var("OTEL_TRACING_ENABLED", "false"); std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500"); let trace_collector = Arc::new(TraceCollector::new(Some(false))); @@ -300,5 +334,9 @@ async fn test_tracing_disabled_produces_no_spans() { let payloads = mock_collector.get_traces().await; let all_spans = extract_spans(&payloads); - assert_eq!(all_spans.len(), 0, "No spans should be captured when tracing is disabled"); + assert_eq!( + all_spans.len(), + 0, + "No spans should be captured when tracing is disabled" + ); } diff --git a/crates/common/src/tracing.rs b/crates/common/src/tracing.rs index 60ccca15..145f987e 100644 --- a/crates/common/src/tracing.rs +++ b/crates/common/src/tracing.rs @@ -161,13 +161,12 @@ impl TraceData { } pub fn new_with_service_name(service_name: String) -> Self { - let mut resource_attributes = Vec::new(); - resource_attributes.push(Attribute { + let resource_attributes = vec![Attribute { key: "service.name".to_string(), value: AttributeValue { string_value: Some(service_name), }, - }); + }]; let resource = Resource { attributes: resource_attributes, @@ -194,7 +193,9 @@ impl TraceData { pub fn add_span(&mut self, span: Span) { if self.resource_spans.is_empty() { - let resource = Resource { attributes: Vec::new() }; + let resource = Resource { + attributes: Vec::new(), + }; let scope_span = ScopeSpan { scope: Scope { name: "default".to_string(), diff --git a/crates/hermesllm/src/apis/amazon_bedrock.rs b/crates/hermesllm/src/apis/amazon_bedrock.rs index 3c953850..dbada283 100644 --- a/crates/hermesllm/src/apis/amazon_bedrock.rs +++ b/crates/hermesllm/src/apis/amazon_bedrock.rs @@ -66,7 +66,7 @@ impl ApiDefinition for AmazonBedrockApi { /// Amazon Bedrock Converse request #[skip_serializing_none] -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Serialize, Deserialize, Debug, Clone, Default)] pub struct ConverseRequest { /// The model ID or ARN to invoke pub model_id: String, @@ -91,7 +91,7 @@ pub struct ConverseRequest { pub additional_model_response_field_paths: Option>, /// Performance configuration #[serde(rename = "performanceConfig")] - pub performance_config: Option, + pub performance_config: Option, /// Prompt variables for Prompt management #[serde(rename = "promptVariables")] pub prompt_variables: Option>, @@ -105,26 +105,6 @@ pub struct ConverseRequest { pub stream: bool, } -impl Default for ConverseRequest { - fn default() -> Self { - Self { - model_id: String::new(), - messages: None, - system: None, - inference_config: None, - tool_config: None, - guardrail_config: None, - additional_model_request_fields: None, - additional_model_response_field_paths: None, - performance_config: None, - prompt_variables: None, - request_metadata: None, - metadata: None, - stream: false, - } - } -} - /// Amazon Bedrock ConverseStream request (same structure as Converse) pub type ConverseStreamRequest = ConverseRequest; @@ -204,8 +184,8 @@ impl ProviderRequest for ConverseRequest { self.tool_config.as_ref()?.tools.as_ref().map(|tools| { tools .iter() - .filter_map(|tool| match tool { - Tool::ToolSpec { tool_spec } => Some(tool_spec.name.clone()), + .map(|tool| match tool { + Tool::ToolSpec { tool_spec } => tool_spec.name.clone(), }) .collect() }) @@ -242,17 +222,14 @@ impl ProviderRequest for ConverseRequest { // Add system messages if present if let Some(system) = &self.system { for sys_block in system { - match sys_block { - SystemContentBlock::Text { text } => { - openai_messages.push(Message { - role: Role::System, - content: MessageContent::Text(text.clone()), - name: None, - tool_calls: None, - tool_call_id: None, - }); - } - _ => {} // Skip other system content types + if let SystemContentBlock::Text { text } = sys_block { + openai_messages.push(Message { + role: Role::System, + content: MessageContent::Text(text.clone()), + name: None, + tool_calls: None, + tool_call_id: None, + }); } } } @@ -266,7 +243,9 @@ impl ProviderRequest for ConverseRequest { }; // Extract text from content blocks - let content = msg.content.iter() + let content = msg + .content + .iter() .filter_map(|block| { if let ContentBlock::Text { text } = block { Some(text.clone()) @@ -311,16 +290,14 @@ impl ProviderRequest for ConverseRequest { _ => continue, }; - let content = if let crate::apis::openai::MessageContent::Text(text) = &msg.content { - vec![ContentBlock::Text { text: text.clone() }] - } else { - vec![] - }; + let content = + if let crate::apis::openai::MessageContent::Text(text) = &msg.content { + vec![ContentBlock::Text { text: text.clone() }] + } else { + vec![] + }; - bedrock_messages.push(crate::apis::amazon_bedrock::Message { - role, - content, - }); + bedrock_messages.push(crate::apis::amazon_bedrock::Message { role, content }); } _ => {} } @@ -369,7 +346,7 @@ pub enum ConverseStreamEvent { ContentBlockDelta(ContentBlockDeltaEvent), ContentBlockStop(ContentBlockStopEvent), MessageStop(MessageStopEvent), - Metadata(ConverseStreamMetadataEvent), + Metadata(Box), // Error events InternalServerException(BedrockException), ModelStreamErrorException(BedrockException), @@ -1063,7 +1040,7 @@ impl TryFrom<&aws_smithy_eventstream::frame::DecodedFrame> for ConverseStreamEve "metadata" => { let event: ConverseStreamMetadataEvent = serde_json::from_slice(payload).map_err(BedrockError::Serialization)?; - Ok(ConverseStreamEvent::Metadata(event)) + Ok(ConverseStreamEvent::Metadata(Box::new(event))) } unknown => Err(BedrockError::Validation { message: format!("Unknown event type: {}", unknown), @@ -1106,10 +1083,10 @@ impl TryFrom<&aws_smithy_eventstream::frame::DecodedFrame> for ConverseStreamEve } } -impl Into for ConverseStreamEvent { - fn into(self) -> String { - let transformed_json = serde_json::to_string(&self).unwrap_or_default(); - let event_type = match &self { +impl From for String { + fn from(val: ConverseStreamEvent) -> String { + let transformed_json = serde_json::to_string(&val).unwrap_or_default(); + let event_type = match &val { ConverseStreamEvent::MessageStart { .. } => "message_start", ConverseStreamEvent::ContentBlockStart { .. } => "content_block_start", ConverseStreamEvent::ContentBlockDelta { .. } => "content_block_delta", diff --git a/crates/hermesllm/src/apis/anthropic.rs b/crates/hermesllm/src/apis/anthropic.rs index 2e73e1c2..ed3317ce 100644 --- a/crates/hermesllm/src/apis/anthropic.rs +++ b/crates/hermesllm/src/apis/anthropic.rs @@ -516,9 +516,9 @@ impl ProviderRequest for MessagesRequest { } fn get_tool_names(&self) -> Option> { - self.tools.as_ref().map(|tools| { - tools.iter().map(|tool| tool.name.clone()).collect() - }) + self.tools + .as_ref() + .map(|tools| tools.iter().map(|tool| tool.name.clone()).collect()) } fn to_bytes(&self) -> Result, ProviderRequestError> { @@ -529,7 +529,7 @@ impl ProviderRequest for MessagesRequest { } fn metadata(&self) -> &Option> { - return &self.metadata; + &self.metadata } fn remove_metadata_key(&mut self, key: &str) -> bool { @@ -581,7 +581,8 @@ impl ProviderRequest for MessagesRequest { // Set system prompt if there are system messages if !system_messages.is_empty() { // Combine all system messages into one - let system_text = system_messages.iter() + let system_text = system_messages + .iter() .filter_map(|msg| { if let crate::apis::openai::MessageContent::Text(text) = &msg.content { Some(text.as_str()) @@ -592,14 +593,15 @@ impl ProviderRequest for MessagesRequest { .collect::>() .join("\n"); - self.system = Some(crate::apis::anthropic::MessagesSystemPrompt::Single(system_text)); + self.system = Some(crate::apis::anthropic::MessagesSystemPrompt::Single( + system_text, + )); } // Convert regular messages - self.messages = regular_messages.iter() - .filter_map(|msg| { - msg.clone().try_into().ok() - }) + self.messages = regular_messages + .iter() + .filter_map(|msg| msg.clone().try_into().ok()) .collect(); } } @@ -1298,7 +1300,7 @@ mod tests { }, { "type": "text", - "text": "\nYou are an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user.\n\nIMPORTANT: Assist with defensive security tasks only. Refuse to create, modify, or improve code that may be used maliciously. Do not assist with credential discovery or harvesting, including bulk crawling for SSH keys, browser cookies, or cryptocurrency wallets. Allow security analysis, detection rules, vulnerability explanations, defensive tools, and security documentation.\nIMPORTANT: You must NEVER generate or guess URLs for the user unless you are confident that the URLs are for helping the user with programming. You may use URLs provided by the user in their messages or local files.\n\nIf the user asks for help or wants to give feedback inform them of the following: \n- /help: Get help with using Claude Code\n- To give feedback, users should report the issue at https://github.com/anthropics/claude-code/issues\n\nWhen⁠ the user directly asks about Claude Code (eg. \"can Claude Code do...\", \"does Claude Code have...\"), or asks in second person (eg. \"are you able...\", \"can you do...\"), or asks how to use a specific Claude Code feature (eg. implement a hook, or write a slash command), use the WebFetch tool to gather information to answer the question from Claude Code docs. The list of available docs is available at https://docs.claude.com/en/docs/claude-code/claude_code_docs_map.md.\n\n#⁠ Tone and style\nYou should be concise, direct, and to the point.\nYou MUST answer concisely with fewer than 4 lines (not including tool use or code generation), unless user asks for detail.\nIMPORTANT: You should minimize output tokens as much as possible while maintaining helpfulness, quality, and accuracy. Only address the specific task at hand, avoiding tangential information unless absolutely critical for completing the request. If you can answer in 1-3 sentences or a short paragraph, please do.\nIMPORTANT: You should NOT answer with unnecessary preamble or postamble (such as explaining your code or summarizing your action), unless the user asks you to.\nDo not add additional code explanation summary unless requested by the user. After working on a file, just stop, rather than providing an explanation of what you did.\nAnswer the user's question directly, avoiding any elaboration, explanation, introduction, conclusion, or excessive details. One word answers are best. You MUST avoid text before/after your response, such as \"The answer is .\", \"Here is the content of the file...\" or \"Based on the information provided, the answer is...\" or \"Here is what I will do next...\".\n\nHere are some examples to demonstrate appropriate verbosity:\n\nuser: 2 + 2\nassistant: 4\n\n\n\nuser: what is 2+2?\nassistant: 4\n\n\n\nuser: is 11 a prime number?\nassistant: Yes\n\n\n\nuser: what command should I run to list files in the current directory?\nassistant: ls\n\n\n\nuser: what command should I run to watch files in the current directory?\nassistant: [runs ls to list the files in the current directory, then read docs/commands in the relevant file to find out how to watch files]\nnpm run dev\n\n\n\nuser: How many golf balls fit inside a jetta?\nassistant: 150000\n\n\n\nuser: what files are in the directory src/?\nassistant: [runs ls and sees foo.c, bar.c, baz.c]\nuser: which file contains the implementation of foo?\nassistant: src/foo.c\n\nWhen you run a non-trivial bash command, you should explain what the command does and why you are running it, to make sure the user understands what you are doing (this is especially important when you are running a command that will make changes to the user's system).\nRemember that your output will be displayed on a command line interface. Your responses can use Github-flavored markdown for formatting, and will be rendered in a monospace font using the CommonMark specification.\nOutput text to communicate with the user; all text you output outside of tool use is displayed to the user. Only use tools to complete tasks. Never use tools like Bash or code comments as means to communicate with the user during the session.\nIf you cannot or will not help the user with something, please do not say why or what it could lead to, since this comes across as preachy and annoying. Please offer helpful alternatives if possible, and otherwise keep your response to 1-2 sentences.\nOnly use emojis if the user explicitly requests it. Avoid using emojis in all communication unless asked.\nIMPORTANT: Keep your responses short, since they will be displayed on a command line interface.\n\n# Proactiveness\nYou are allowed to be proactive, but only when the user asks you to do something. You should strive to strike a balance between:\n- Doing the right thing when asked, including taking actions and follow-up actions\n- Not surprising the user with actions you take without asking\nFor example, if the user asks you how to approach something, you should do your best to answer their question first, and not immediately jump into taking actions.\n\n# Professional objectivity\nPrioritize technical accuracy and truthfulness over validating the user's beliefs. Focus on facts and problem-solving, providing direct, objective technical info without any unnecessary superlatives, praise, or emotional validation. It is best for the user if Claude honestly applies the same rigorous standards to all ideas and disagrees when necessary, even if it may not be what the user wants to hear. Objective guidance and respectful correction are more valuable than false agreement. Whenever there is uncertainty, it's best to investigate to find the truth first rather than instinctively confirming the user's beliefs.\n\n# Following conventions\nWhen making changes to files, first understand the file's code conventions. Mimic code style, use existing libraries and utilities, and follow existing patterns.\n- NEVER assume that a given library is available, even if it is well known. Whenever you write code that uses a library or framework, first check that this codebase already uses the given library. For example, you might look at neighboring files, or check the package.json (or cargo.toml, and so on depending on the language).\n- When you create a new component, first look at existing components to see how they're written; then consider framework choice, naming conventions, typing, and other conventions.\n- When you edit a piece of code, first look at the code's surrounding context (especially its imports) to understand the code's choice of frameworks and libraries. Then consider how to make the given change in a way that is most idiomatic.\n- Always follow security best practices. Never introduce code that exposes or logs secrets and keys. Never commit secrets or keys to the repository.\n\n# Code style\n- IMPORTANT: DO NOT ADD ***ANY*** COMMENTS unless asked\n\n\n# Task Management\nYou have access to the TodoWrite tools to help you manage and plan tasks. Use these tools VERY frequently to ensure that you are tracking your tasks and giving the user visibility into your progress.\nThese tools are also EXTREMELY helpful for planning tasks, and for breaking down larger complex tasks into smaller steps. If you do not use this tool when planning, you may forget to do important tasks - and that is unacceptable.\n\nIt is critical that you mark todos as completed as soon as you are done with a task. Do not batch up multiple tasks before marking them as completed.\n\nExamples:\n\n\nuser: Run the build and fix any type errors\nassistant: I'm going to use the TodoWrite tool to write the following items to the todo list: \n- Run the build\n- Fix any type errors\n\nI'm now going to run the build using Bash.\n\nLooks like I found 10 type errors. I'm going to use the TodoWrite tool to write 10 items to the todo list.\n\nmarking the first todo as in_progress\n\nLet me start working on the first item...\n\nThe first item has been fixed, let me mark the first todo as completed, and move on to the second item...\n..\n..\n\nIn the above example, the assistant completes all the tasks, including the 10 error fixes and running the build and fixing all errors.\n\n\nuser: Help me write a new feature that allows users to track their usage metrics and export them to various formats\n\nassistant: I'll help you implement a usage metrics tracking and export feature. Let me first use the TodoWrite tool to plan this task.\nAdding the following todos to the todo list:\n1. Research existing metrics tracking in the codebase\n2. Design the metrics collection system\n3. Implement core metrics tracking functionality\n4. Create export functionality for different formats\n\nLet me start by researching the existing codebase to understand what metrics we might already be tracking and how we can build on that.\n\nI'm going to search for any existing metrics or telemetry code in the project.\n\nI've found some existing telemetry code. Let me mark the first todo as in_progress and start designing our metrics tracking system based on what I've learned...\n\n[Assistant continues implementing the feature step by step, marking todos as in_progress and completed as they go]\n\n\n\nUsers may configure 'hooks', shell commands that execute in response to events like tool calls, in settings. Treat feedback from hooks, including , as coming from the user. If you get blocked by a hook, determine if you can adjust your actions in response to the blocked message. If not, ask the user to check their hooks configuration.\n\n# Doing tasks\nThe user will primarily request you perform software engineering tasks. This includes solving bugs, adding new functionality, refactoring code, explaining code, and more. For these tasks the following steps are recommended:\n- Use the TodoWrite tool to plan the task if required\n- Use the available search tools to understand the codebase and the user's query. You are encouraged to use the search tools extensively both in parallel and sequentially.\n- Implement the solution using all tools available to you\n- Verify the solution if possible with tests. NEVER assume specific test framework or test script. Check the README or search codebase to determine the testing approach.\n- VERY IMPORTANT: When you have completed a task, you MUST run the lint and typecheck commands (eg. npm run lint, npm run typecheck, ruff, etc.) with Bash if they were provided to you to ensure your code is correct. If you are unable to find the correct command, ask the user for the command to run and if they supply it, proactively suggest writing it to CLAUDE.md so that you will know to run it next time.\nNEVER commit changes unless the user explicitly asks you to. It is VERY IMPORTANT to only commit when explicitly asked, otherwise the user will feel that you are being too proactive.\n\n- Tool results and user messages may include tags. tags contain useful information and reminders. They are automatically added by the system, and bear no direct relation to the specific tool results or user messages in which they appear.\n\n\n# Tool usage policy\n- When doing file search, prefer to use the Task tool in order to reduce context usage.\n- You should proactively use the Task tool with specialized agents when the task at hand matches the agent's description.\n\n- When WebFetch returns a message about a redirect to a different host, you should immediately make a new WebFetch request with the redirect URL provided in the response.\n- You have the capability to call multiple tools in a single response. When multiple independent pieces of information are requested, batch your tool calls together for optimal performance. When making multiple bash tool calls, you MUST send a single message with multiple tools calls to run the calls in parallel. For example, if you need to run \"git status\" and \"git diff\", send a single message with two tool calls to run the calls in parallel.\n- If the user specifies that they want you to run tools \"in parallel\", you MUST send a single message with multiple tool use content blocks. For example, if you need to launch multiple agents in parallel, send a single message with multiple Task tool calls.\n\n\n\n\nHere is useful information about the environment you are running in:\n\nWorking directory: /Users/salmanparacha/arch/crates/llm_gateway\nIs directory a git repo: Yes\nPlatform: darwin\nOS Version: Darwin 25.0.0\nToday's date: 2025-09-25\n\nYou are powered by the model named Sonnet 4. The exact model ID is claude-sonnet-4-20250514.\n\nAssistant knowledge cutoff is January 2025.\n\n\nIMPORTANT: Assist with defensive security tasks only. Refuse to create, modify, or improve code that may be used maliciously. Do not assist with credential discovery or harvesting, including bulk crawling for 2025-09-25T22:19:13.499582010Z SSH keys, browser cookies, or cryptocurrency wallets. Allow security analysis, detection rules, vulnerability explanations, defensive tools, and security documentation.\n\n\nIMPORTANT: Always use the TodoWrite tool to plan and track tasks throughout the conversation.\n\n# Code References\n\nWhen referencing specific functions or pieces of code include the pattern `file_path:line_number` to allow the user to easily navigate to the source code location.\n\n\nuser: Where are errors from the client handled?\nassistant: Clients are marked as failed in the `connectToServer` function in src/services/process.ts:712.\n\n", + "text": "\nYou are an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user.\n\nIMPORTANT: Assist with defensive security tasks only. Refuse to create, modify, or improve code that may be used maliciously. Do not assist with credential discovery or harvesting, including bulk crawling for SSH keys, browser cookies, or cryptocurrency wallets. Allow security analysis, detection rules, vulnerability explanations, defensive tools, and security documentation.\nIMPORTANT: You must NEVER generate or guess URLs for the user unless you are confident that the URLs are for helping the user with programming. You may use URLs provided by the user in their messages or local files.\n\nIf the user asks for help or wants to give feedback inform them of the following: \n- /help: Get help with using Claude Code\n- To give feedback, users should report the issue at https://github.com/anthropics/claude-code/issues\n\nWhen\u{2060} the user directly asks about Claude Code (eg. \"can Claude Code do...\", \"does Claude Code have...\"), or asks in second person (eg. \"are you able...\", \"can you do...\"), or asks how to use a specific Claude Code feature (eg. implement a hook, or write a slash command), use the WebFetch tool to gather information to answer the question from Claude Code docs. The list of available docs is available at https://docs.claude.com/en/docs/claude-code/claude_code_docs_map.md.\n\n#\u{2060} Tone and style\nYou should be concise, direct, and to the point.\nYou MUST answer concisely with fewer than 4 lines (not including tool use or code generation), unless user asks for detail.\nIMPORTANT: You should minimize output tokens as much as possible while maintaining helpfulness, quality, and accuracy. Only address the specific task at hand, avoiding tangential information unless absolutely critical for completing the request. If you can answer in 1-3 sentences or a short paragraph, please do.\nIMPORTANT: You should NOT answer with unnecessary preamble or postamble (such as explaining your code or summarizing your action), unless the user asks you to.\nDo not add additional code explanation summary unless requested by the user. After working on a file, just stop, rather than providing an explanation of what you did.\nAnswer the user's question directly, avoiding any elaboration, explanation, introduction, conclusion, or excessive details. One word answers are best. You MUST avoid text before/after your response, such as \"The answer is .\", \"Here is the content of the file...\" or \"Based on the information provided, the answer is...\" or \"Here is what I will do next...\".\n\nHere are some examples to demonstrate appropriate verbosity:\n\nuser: 2 + 2\nassistant: 4\n\n\n\nuser: what is 2+2?\nassistant: 4\n\n\n\nuser: is 11 a prime number?\nassistant: Yes\n\n\n\nuser: what command should I run to list files in the current directory?\nassistant: ls\n\n\n\nuser: what command should I run to watch files in the current directory?\nassistant: [runs ls to list the files in the current directory, then read docs/commands in the relevant file to find out how to watch files]\nnpm run dev\n\n\n\nuser: How many golf balls fit inside a jetta?\nassistant: 150000\n\n\n\nuser: what files are in the directory src/?\nassistant: [runs ls and sees foo.c, bar.c, baz.c]\nuser: which file contains the implementation of foo?\nassistant: src/foo.c\n\nWhen you run a non-trivial bash command, you should explain what the command does and why you are running it, to make sure the user understands what you are doing (this is especially important when you are running a command that will make changes to the user's system).\nRemember that your output will be displayed on a command line interface. Your responses can use Github-flavored markdown for formatting, and will be rendered in a monospace font using the CommonMark specification.\nOutput text to communicate with the user; all text you output outside of tool use is displayed to the user. Only use tools to complete tasks. Never use tools like Bash or code comments as means to communicate with the user during the session.\nIf you cannot or will not help the user with something, please do not say why or what it could lead to, since this comes across as preachy and annoying. Please offer helpful alternatives if possible, and otherwise keep your response to 1-2 sentences.\nOnly use emojis if the user explicitly requests it. Avoid using emojis in all communication unless asked.\nIMPORTANT: Keep your responses short, since they will be displayed on a command line interface.\n\n# Proactiveness\nYou are allowed to be proactive, but only when the user asks you to do something. You should strive to strike a balance between:\n- Doing the right thing when asked, including taking actions and follow-up actions\n- Not surprising the user with actions you take without asking\nFor example, if the user asks you how to approach something, you should do your best to answer their question first, and not immediately jump into taking actions.\n\n# Professional objectivity\nPrioritize technical accuracy and truthfulness over validating the user's beliefs. Focus on facts and problem-solving, providing direct, objective technical info without any unnecessary superlatives, praise, or emotional validation. It is best for the user if Claude honestly applies the same rigorous standards to all ideas and disagrees when necessary, even if it may not be what the user wants to hear. Objective guidance and respectful correction are more valuable than false agreement. Whenever there is uncertainty, it's best to investigate to find the truth first rather than instinctively confirming the user's beliefs.\n\n# Following conventions\nWhen making changes to files, first understand the file's code conventions. Mimic code style, use existing libraries and utilities, and follow existing patterns.\n- NEVER assume that a given library is available, even if it is well known. Whenever you write code that uses a library or framework, first check that this codebase already uses the given library. For example, you might look at neighboring files, or check the package.json (or cargo.toml, and so on depending on the language).\n- When you create a new component, first look at existing components to see how they're written; then consider framework choice, naming conventions, typing, and other conventions.\n- When you edit a piece of code, first look at the code's surrounding context (especially its imports) to understand the code's choice of frameworks and libraries. Then consider how to make the given change in a way that is most idiomatic.\n- Always follow security best practices. Never introduce code that exposes or logs secrets and keys. Never commit secrets or keys to the repository.\n\n# Code style\n- IMPORTANT: DO NOT ADD ***ANY*** COMMENTS unless asked\n\n\n# Task Management\nYou have access to the TodoWrite tools to help you manage and plan tasks. Use these tools VERY frequently to ensure that you are tracking your tasks and giving the user visibility into your progress.\nThese tools are also EXTREMELY helpful for planning tasks, and for breaking down larger complex tasks into smaller steps. If you do not use this tool when planning, you may forget to do important tasks - and that is unacceptable.\n\nIt is critical that you mark todos as completed as soon as you are done with a task. Do not batch up multiple tasks before marking them as completed.\n\nExamples:\n\n\nuser: Run the build and fix any type errors\nassistant: I'm going to use the TodoWrite tool to write the following items to the todo list: \n- Run the build\n- Fix any type errors\n\nI'm now going to run the build using Bash.\n\nLooks like I found 10 type errors. I'm going to use the TodoWrite tool to write 10 items to the todo list.\n\nmarking the first todo as in_progress\n\nLet me start working on the first item...\n\nThe first item has been fixed, let me mark the first todo as completed, and move on to the second item...\n..\n..\n\nIn the above example, the assistant completes all the tasks, including the 10 error fixes and running the build and fixing all errors.\n\n\nuser: Help me write a new feature that allows users to track their usage metrics and export them to various formats\n\nassistant: I'll help you implement a usage metrics tracking and export feature. Let me first use the TodoWrite tool to plan this task.\nAdding the following todos to the todo list:\n1. Research existing metrics tracking in the codebase\n2. Design the metrics collection system\n3. Implement core metrics tracking functionality\n4. Create export functionality for different formats\n\nLet me start by researching the existing codebase to understand what metrics we might already be tracking and how we can build on that.\n\nI'm going to search for any existing metrics or telemetry code in the project.\n\nI've found some existing telemetry code. Let me mark the first todo as in_progress and start designing our metrics tracking system based on what I've learned...\n\n[Assistant continues implementing the feature step by step, marking todos as in_progress and completed as they go]\n\n\n\nUsers may configure 'hooks', shell commands that execute in response to events like tool calls, in settings. Treat feedback from hooks, including , as coming from the user. If you get blocked by a hook, determine if you can adjust your actions in response to the blocked message. If not, ask the user to check their hooks configuration.\n\n# Doing tasks\nThe user will primarily request you perform software engineering tasks. This includes solving bugs, adding new functionality, refactoring code, explaining code, and more. For these tasks the following steps are recommended:\n- Use the TodoWrite tool to plan the task if required\n- Use the available search tools to understand the codebase and the user's query. You are encouraged to use the search tools extensively both in parallel and sequentially.\n- Implement the solution using all tools available to you\n- Verify the solution if possible with tests. NEVER assume specific test framework or test script. Check the README or search codebase to determine the testing approach.\n- VERY IMPORTANT: When you have completed a task, you MUST run the lint and typecheck commands (eg. npm run lint, npm run typecheck, ruff, etc.) with Bash if they were provided to you to ensure your code is correct. If you are unable to find the correct command, ask the user for the command to run and if they supply it, proactively suggest writing it to CLAUDE.md so that you will know to run it next time.\nNEVER commit changes unless the user explicitly asks you to. It is VERY IMPORTANT to only commit when explicitly asked, otherwise the user will feel that you are being too proactive.\n\n- Tool results and user messages may include tags. tags contain useful information and reminders. They are automatically added by the system, and bear no direct relation to the specific tool results or user messages in which they appear.\n\n\n# Tool usage policy\n- When doing file search, prefer to use the Task tool in order to reduce context usage.\n- You should proactively use the Task tool with specialized agents when the task at hand matches the agent's description.\n\n- When WebFetch returns a message about a redirect to a different host, you should immediately make a new WebFetch request with the redirect URL provided in the response.\n- You have the capability to call multiple tools in a single response. When multiple independent pieces of information are requested, batch your tool calls together for optimal performance. When making multiple bash tool calls, you MUST send a single message with multiple tools calls to run the calls in parallel. For example, if you need to run \"git status\" and \"git diff\", send a single message with two tool calls to run the calls in parallel.\n- If the user specifies that they want you to run tools \"in parallel\", you MUST send a single message with multiple tool use content blocks. For example, if you need to launch multiple agents in parallel, send a single message with multiple Task tool calls.\n\n\n\n\nHere is useful information about the environment you are running in:\n\nWorking directory: /Users/salmanparacha/arch/crates/llm_gateway\nIs directory a git repo: Yes\nPlatform: darwin\nOS Version: Darwin 25.0.0\nToday's date: 2025-09-25\n\nYou are powered by the model named Sonnet 4. The exact model ID is claude-sonnet-4-20250514.\n\nAssistant knowledge cutoff is January 2025.\n\n\nIMPORTANT: Assist with defensive security tasks only. Refuse to create, modify, or improve code that may be used maliciously. Do not assist with credential discovery or harvesting, including bulk crawling for 2025-09-25T22:19:13.499582010Z SSH keys, browser cookies, or cryptocurrency wallets. Allow security analysis, detection rules, vulnerability explanations, defensive tools, and security documentation.\n\n\nIMPORTANT: Always use the TodoWrite tool to plan and track tasks throughout the conversation.\n\n# Code References\n\nWhen referencing specific functions or pieces of code include the pattern `file_path:line_number` to allow the user to easily navigate to the source code location.\n\n\nuser: Where are errors from the client handled?\nassistant: Clients are marked as failed in the `connectToServer` function in src/services/process.ts:712.\n\n", "cache_control": { "type": "ephemeral" } diff --git a/crates/hermesllm/src/apis/openai.rs b/crates/hermesllm/src/apis/openai.rs index 79154d39..834c33ec 100644 --- a/crates/hermesllm/src/apis/openai.rs +++ b/crates/hermesllm/src/apis/openai.rs @@ -286,7 +286,6 @@ pub struct ImageUrl { } /// A single message in a chat conversation - /// A tool call made by the assistant #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] pub struct ToolCall { @@ -388,7 +387,7 @@ pub enum StaticContentType { /// Chat completions API response #[skip_serializing_none] -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Serialize, Deserialize, Debug, Clone, Default)] pub struct ChatCompletionsResponse { pub id: String, pub object: Option, @@ -402,22 +401,6 @@ pub struct ChatCompletionsResponse { pub metadata: Option>, } -impl Default for ChatCompletionsResponse { - fn default() -> Self { - ChatCompletionsResponse { - id: String::new(), - object: None, - created: 0, - model: String::new(), - choices: vec![], - usage: Usage::default(), - system_fingerprint: None, - service_tier: None, - metadata: None, - } - } -} - /// Finish reason for completion #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] #[serde(rename_all = "snake_case")] @@ -431,7 +414,7 @@ pub enum FinishReason { /// Token usage information #[skip_serializing_none] -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Serialize, Deserialize, Debug, Clone, Default)] pub struct Usage { pub prompt_tokens: u32, pub completion_tokens: u32, @@ -440,18 +423,6 @@ pub struct Usage { pub completion_tokens_details: Option, } -impl Default for Usage { - fn default() -> Self { - Usage { - prompt_tokens: 0, - completion_tokens: 0, - total_tokens: 0, - prompt_tokens_details: None, - completion_tokens_details: None, - } - } -} - /// Detailed breakdown of prompt tokens #[skip_serializing_none] #[derive(Serialize, Deserialize, Debug, Clone)] @@ -472,7 +443,7 @@ pub struct CompletionTokensDetails { /// A single choice in the response #[skip_serializing_none] -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Serialize, Deserialize, Debug, Clone, Default)] pub struct Choice { pub index: u32, pub message: ResponseMessage, @@ -480,17 +451,6 @@ pub struct Choice { pub logprobs: Option, } -impl Default for Choice { - fn default() -> Self { - Choice { - index: 0, - message: ResponseMessage::default(), - finish_reason: None, - logprobs: None, - } - } -} - // ============================================================================ // STREAMING API TYPES // ============================================================================ @@ -608,7 +568,6 @@ pub enum OpenAIError { // ============================================================================ /// Trait Implementations /// =========================================================================== - /// Parameterized conversion for ChatCompletionsRequest impl TryFrom<&[u8]> for ChatCompletionsRequest { type Error = OpenAIStreamError; @@ -721,7 +680,7 @@ impl ProviderRequest for ChatCompletionsRequest { } fn metadata(&self) -> &Option> { - return &self.metadata; + &self.metadata } fn remove_metadata_key(&mut self, key: &str) -> bool { diff --git a/crates/hermesllm/src/apis/openai_responses.rs b/crates/hermesllm/src/apis/openai_responses.rs index 6afe9f09..e49173dc 100644 --- a/crates/hermesllm/src/apis/openai_responses.rs +++ b/crates/hermesllm/src/apis/openai_responses.rs @@ -1,7 +1,7 @@ -use std::collections::HashMap; +use crate::providers::request::{ProviderRequest, ProviderRequestError}; use serde::{Deserialize, Serialize}; use serde_with::skip_serializing_none; -use crate::providers::request::{ProviderRequest, ProviderRequestError}; +use std::collections::HashMap; impl TryFrom<&[u8]> for ResponsesAPIRequest { type Error = serde_json::Error; @@ -172,18 +172,14 @@ pub enum MessageRole { #[serde(tag = "type", rename_all = "snake_case")] pub enum InputContent { /// Text input - InputText { - text: String, - }, + InputText { text: String }, /// Image input via URL InputImage { image_url: String, detail: Option, }, /// File input via URL - InputFile { - file_url: String, - }, + InputFile { file_url: String }, /// Audio input InputAudio { data: Option, @@ -222,9 +218,7 @@ pub struct TextConfig { pub enum TextFormat { Text, JsonObject, - JsonSchema { - json_schema: serde_json::Value, - }, + JsonSchema { json_schema: serde_json::Value }, } /// Reasoning effort levels @@ -608,9 +602,7 @@ pub enum OutputContent { transcript: Option, }, /// Refusal output - Refusal { - refusal: String, - }, + Refusal { refusal: String }, } /// Annotations for output text @@ -663,13 +655,9 @@ pub struct FileSearchResult { #[serde(tag = "type", rename_all = "snake_case")] pub enum CodeInterpreterOutput { /// Text output - Text { - text: String, - }, + Text { text: String }, /// Image output - Image { - image: String, - }, + Image { image: String }, } /// Response usage statistics @@ -951,9 +939,7 @@ pub enum ResponsesAPIStreamEvent { }, /// Done event (end of stream) - Done { - sequence_number: i32, - }, + Done { sequence_number: i32 }, } // ============================================================================ @@ -1052,12 +1038,19 @@ impl ProviderRequest for ResponsesAPIRequest { MessageContent::Text(text) => text.clone(), MessageContent::Items(content_items) => { content_items.iter().fold(String::new(), |acc, content| { - acc + " " + &match content { - InputContent::InputText { text } => text.clone(), - InputContent::InputImage { .. } => "[Image]".to_string(), - InputContent::InputFile { .. } => "[File]".to_string(), - InputContent::InputAudio { .. } => "[Audio]".to_string(), - } + acc + " " + + &match content { + InputContent::InputText { text } => text.clone(), + InputContent::InputImage { .. } => { + "[Image]".to_string() + } + InputContent::InputFile { .. } => { + "[File]".to_string() + } + InputContent::InputAudio { .. } => { + "[Audio]".to_string() + } + } }) } }; @@ -1082,11 +1075,9 @@ impl ProviderRequest for ResponsesAPIRequest { match &msg.content { MessageContent::Text(text) => Some(text.clone()), MessageContent::Items(content_items) => { - content_items.iter().find_map(|content| { - match content { - InputContent::InputText { text } => Some(text.clone()), - _ => None, - } + content_items.iter().find_map(|content| match content { + InputContent::InputText { text } => Some(text.clone()), + _ => None, }) } } @@ -1176,9 +1167,12 @@ impl ProviderRequest for ResponsesAPIRequest { // Extract text from message content let content = match &msg.content { - crate::apis::openai_responses::MessageContent::Text(text) => text.clone(), + crate::apis::openai_responses::MessageContent::Text(text) => { + text.clone() + } crate::apis::openai_responses::MessageContent::Items(items) => { - items.iter() + items + .iter() .filter_map(|c| { if let InputContent::InputText { text } = c { Some(text.clone()) @@ -1214,7 +1208,8 @@ impl ProviderRequest for ResponsesAPIRequest { fn set_messages(&mut self, messages: &[crate::apis::openai::Message]) { // For ResponsesAPI, we need to convert messages back to input format // Extract system messages as instructions - let system_text = messages.iter() + let system_text = messages + .iter() .filter(|msg| msg.role == crate::apis::openai::Role::System) .filter_map(|msg| { if let crate::apis::openai::MessageContent::Text(text) = &msg.content { @@ -1233,23 +1228,27 @@ impl ProviderRequest for ResponsesAPIRequest { // Convert user/assistant messages to InputParam // For simplicity, we'll use the last user message as the input // or combine all non-system messages - let input_messages: Vec<_> = messages.iter() + let input_messages: Vec<_> = messages + .iter() .filter(|msg| msg.role != crate::apis::openai::Role::System) .collect(); if !input_messages.is_empty() { // If there's only one message, use Text format if input_messages.len() == 1 { - if let crate::apis::openai::MessageContent::Text(text) = &input_messages[0].content { + if let crate::apis::openai::MessageContent::Text(text) = &input_messages[0].content + { self.input = crate::apis::openai_responses::InputParam::Text(text.clone()); } } else { // Multiple messages - combine them as text for now // A more sophisticated approach would use InputParam::Items - let combined_text = input_messages.iter() + let combined_text = input_messages + .iter() .filter_map(|msg| { if let crate::apis::openai::MessageContent::Text(text) = &msg.content { - Some(format!("{}: {}", + Some(format!( + "{}: {}", match msg.role { crate::apis::openai::Role::User => "User", crate::apis::openai::Role::Assistant => "Assistant", @@ -1274,10 +1273,10 @@ impl ProviderRequest for ResponsesAPIRequest { // Into Implementation for SSE Formatting // ============================================================================ -impl Into for ResponsesAPIStreamEvent { - fn into(self) -> String { - let transformed_json = serde_json::to_string(&self).unwrap_or_default(); - let event_type = match &self { +impl From for String { + fn from(val: ResponsesAPIStreamEvent) -> Self { + let transformed_json = serde_json::to_string(&val).unwrap_or_default(); + let event_type = match &val { ResponsesAPIStreamEvent::ResponseCreated { .. } => "response.created", ResponsesAPIStreamEvent::ResponseInProgress { .. } => "response.in_progress", ResponsesAPIStreamEvent::ResponseCompleted { .. } => "response.completed", @@ -1365,10 +1364,10 @@ impl crate::providers::streaming_response::ProviderStreamResponse for ResponsesA fn role(&self) -> Option<&str> { match self { - ResponsesAPIStreamEvent::ResponseOutputItemDone { item, .. } => match item { - OutputItem::Message { role, .. } => Some(role.as_str()), - _ => None, - }, + ResponsesAPIStreamEvent::ResponseOutputItemDone { + item: OutputItem::Message { role, .. }, + .. + } => Some(role.as_str()), _ => None, } } diff --git a/crates/hermesllm/src/apis/streaming_shapes/amazon_bedrock_binary_frame.rs b/crates/hermesllm/src/apis/streaming_shapes/amazon_bedrock_binary_frame.rs index 7f68bb26..5156cd52 100644 --- a/crates/hermesllm/src/apis/streaming_shapes/amazon_bedrock_binary_frame.rs +++ b/crates/hermesllm/src/apis/streaming_shapes/amazon_bedrock_binary_frame.rs @@ -34,10 +34,7 @@ where } pub fn decode_frame(&mut self) -> Option { - match self.decoder.decode_frame(&mut self.buffer) { - Ok(frame) => Some(frame), - Err(_e) => None, // Fatal decode error - } + self.decoder.decode_frame(&mut self.buffer).ok() } pub fn buffer_mut(&mut self) -> &mut B { diff --git a/crates/hermesllm/src/apis/streaming_shapes/anthropic_streaming_buffer.rs b/crates/hermesllm/src/apis/streaming_shapes/anthropic_streaming_buffer.rs index 818ee37d..eb9ec5b1 100644 --- a/crates/hermesllm/src/apis/streaming_shapes/anthropic_streaming_buffer.rs +++ b/crates/hermesllm/src/apis/streaming_shapes/anthropic_streaming_buffer.rs @@ -1,5 +1,5 @@ -use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait}; use crate::apis::anthropic::MessagesStreamEvent; +use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait}; use crate::providers::streaming_response::ProviderStreamResponseType; use std::collections::HashSet; @@ -31,6 +31,12 @@ pub struct AnthropicMessagesStreamBuffer { model: Option, } +impl Default for AnthropicMessagesStreamBuffer { + fn default() -> Self { + Self::new() + } +} + impl AnthropicMessagesStreamBuffer { pub fn new() -> Self { Self { @@ -154,7 +160,8 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer { // Inject message_start if needed if !self.message_started { let model = self.model.as_deref().unwrap_or("unknown"); - let message_start = AnthropicMessagesStreamBuffer::create_message_start_event(model); + let message_start = + AnthropicMessagesStreamBuffer::create_message_start_event(model); self.buffered_events.push(message_start); self.message_started = true; } @@ -169,7 +176,8 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer { // Inject message_start if needed if !self.message_started { let model = self.model.as_deref().unwrap_or("unknown"); - let message_start = AnthropicMessagesStreamBuffer::create_message_start_event(model); + let message_start = + AnthropicMessagesStreamBuffer::create_message_start_event(model); self.buffered_events.push(message_start); self.message_started = true; } @@ -177,7 +185,8 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer { // Check if ContentBlockStart was sent for this index if !self.has_content_block_start_been_sent(index) { // Inject ContentBlockStart before delta - let content_block_start = AnthropicMessagesStreamBuffer::create_content_block_start_event(); + let content_block_start = + AnthropicMessagesStreamBuffer::create_content_block_start_event(); self.buffered_events.push(content_block_start); self.set_content_block_start_sent(index); self.needs_content_block_stop = true; @@ -189,7 +198,8 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer { MessagesStreamEvent::MessageDelta { usage, .. } => { // Inject ContentBlockStop before message_delta if self.needs_content_block_stop { - let content_block_stop = AnthropicMessagesStreamBuffer::create_content_block_stop_event(); + let content_block_stop = + AnthropicMessagesStreamBuffer::create_content_block_stop_event(); self.buffered_events.push(content_block_stop); self.needs_content_block_stop = false; } @@ -199,10 +209,10 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer { if let Some(last_event) = self.buffered_events.last_mut() { if let Some(ProviderStreamResponseType::MessagesStreamEvent( MessagesStreamEvent::MessageDelta { - usage: last_usage, - .. - } - )) = &mut last_event.provider_stream_response { + usage: last_usage, .. + }, + )) = &mut last_event.provider_stream_response + { // Merge: take stop_reason from first, usage from second (if non-zero) if usage.input_tokens > 0 || usage.output_tokens > 0 { *last_usage = usage.clone(); @@ -243,7 +253,7 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer { } } - fn into_bytes(&mut self) -> Vec { + fn to_bytes(&mut self) -> Vec { // Convert all accumulated events to bytes and clear buffer // NOTE: We do NOT inject ContentBlockStop here because it's injected when we see MessageDelta // or MessageStop. Injecting it here causes premature ContentBlockStop in the middle of streaming. @@ -276,10 +286,10 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer { #[cfg(test)] mod tests { use super::*; - use crate::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs}; use crate::apis::anthropic::AnthropicApi; use crate::apis::openai::OpenAIApi; use crate::apis::streaming_shapes::sse::SseStreamIter; + use crate::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs}; #[test] fn test_openai_to_anthropic_complete_transformation() { @@ -308,11 +318,12 @@ data: [DONE]"#; let mut buffer = AnthropicMessagesStreamBuffer::new(); for raw_event in stream_iter { - let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap(); + let transformed_event = + SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap(); buffer.add_transformed_event(transformed_event); } - let output_bytes = buffer.into_bytes(); + let output_bytes = buffer.to_bytes(); let output = String::from_utf8_lossy(&output_bytes); println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):"); @@ -321,25 +332,54 @@ data: [DONE]"#; // Assertions assert!(!output_bytes.is_empty(), "Should have output"); - assert!(output.contains("event: message_start"), "Should have message_start"); - assert!(output.contains("event: content_block_start"), "Should have content_block_start (injected)"); + assert!( + output.contains("event: message_start"), + "Should have message_start" + ); + assert!( + output.contains("event: content_block_start"), + "Should have content_block_start (injected)" + ); let delta_count = output.matches("event: content_block_delta").count(); - assert_eq!(delta_count, 2, "Should have exactly 2 content_block_delta events"); + assert_eq!( + delta_count, 2, + "Should have exactly 2 content_block_delta events" + ); // Verify both pieces of content are present - assert!(output.contains("\"text\":\"Hello\""), "Should have first content delta 'Hello'"); - assert!(output.contains("\"text\":\" world\""), "Should have second content delta ' world'"); + assert!( + output.contains("\"text\":\"Hello\""), + "Should have first content delta 'Hello'" + ); + assert!( + output.contains("\"text\":\" world\""), + "Should have second content delta ' world'" + ); - assert!(output.contains("event: content_block_stop"), "Should have content_block_stop (injected)"); - assert!(output.contains("event: message_delta"), "Should have message_delta"); - assert!(output.contains("event: message_stop"), "Should have message_stop"); + assert!( + output.contains("event: content_block_stop"), + "Should have content_block_stop (injected)" + ); + assert!( + output.contains("event: message_delta"), + "Should have message_delta" + ); + assert!( + output.contains("event: message_stop"), + "Should have message_stop" + ); println!("\nVALIDATION SUMMARY:"); println!("{}", "-".repeat(80)); println!("✓ Complete transformation: OpenAI ChatCompletions → Anthropic Messages API"); - println!("✓ Injected lifecycle events: message_start, content_block_start, content_block_stop"); - println!("✓ Content deltas: {} events (BOTH 'Hello' and ' world' preserved!)", delta_count); + println!( + "✓ Injected lifecycle events: message_start, content_block_start, content_block_stop" + ); + println!( + "✓ Content deltas: {} events (BOTH 'Hello' and ' world' preserved!)", + delta_count + ); println!("✓ Complete stream with message_stop"); println!("✓ Proper Anthropic protocol sequencing\n"); } @@ -369,11 +409,12 @@ data: {"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890 let mut buffer = AnthropicMessagesStreamBuffer::new(); for raw_event in stream_iter { - let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap(); + let transformed_event = + SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap(); buffer.add_transformed_event(transformed_event); } - let output_bytes = buffer.into_bytes(); + let output_bytes = buffer.to_bytes(); let output = String::from_utf8_lossy(&output_bytes); println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):"); @@ -382,31 +423,61 @@ data: {"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890 // Assertions assert!(!output_bytes.is_empty(), "Should have output"); - assert!(output.contains("event: message_start"), "Should have message_start"); - assert!(output.contains("event: content_block_start"), "Should have content_block_start (injected)"); + assert!( + output.contains("event: message_start"), + "Should have message_start" + ); + assert!( + output.contains("event: content_block_start"), + "Should have content_block_start (injected)" + ); let delta_count = output.matches("event: content_block_delta").count(); - assert_eq!(delta_count, 3, "Should have exactly 3 content_block_delta events"); + assert_eq!( + delta_count, 3, + "Should have exactly 3 content_block_delta events" + ); // Verify all three pieces of content are present - assert!(output.contains("\"text\":\"The weather\""), "Should have first content delta"); - assert!(output.contains("\"text\":\" in San Francisco\""), "Should have second content delta"); - assert!(output.contains("\"text\":\" is\""), "Should have third content delta"); + assert!( + output.contains("\"text\":\"The weather\""), + "Should have first content delta" + ); + assert!( + output.contains("\"text\":\" in San Francisco\""), + "Should have second content delta" + ); + assert!( + output.contains("\"text\":\" is\""), + "Should have third content delta" + ); // For partial streams (no finish_reason, no [DONE]), we do NOT inject content_block_stop // because the stream may continue. This is correct behavior - only inject lifecycle events // when we have explicit signals from upstream (finish_reason, [DONE], etc.) - assert!(!output.contains("event: content_block_stop"), "Should NOT have content_block_stop for partial stream"); + assert!( + !output.contains("event: content_block_stop"), + "Should NOT have content_block_stop for partial stream" + ); // Should NOT have completion events - assert!(!output.contains("event: message_delta"), "Should NOT have message_delta"); - assert!(!output.contains("event: message_stop"), "Should NOT have message_stop"); + assert!( + !output.contains("event: message_delta"), + "Should NOT have message_delta" + ); + assert!( + !output.contains("event: message_stop"), + "Should NOT have message_stop" + ); println!("\nVALIDATION SUMMARY:"); println!("{}", "-".repeat(80)); println!("✓ Partial transformation: OpenAI → Anthropic (stream interrupted)"); println!("✓ Injected: message_start, content_block_start at beginning"); - println!("✓ Incremental deltas: {} events (ALL content preserved!)", delta_count); + println!( + "✓ Incremental deltas: {} events (ALL content preserved!)", + delta_count + ); println!("✓ NO completion events (partial stream, no [DONE])"); println!("✓ Buffer maintains Anthropic protocol for active streams\n"); } @@ -452,11 +523,12 @@ data: [DONE]"#; let mut buffer = AnthropicMessagesStreamBuffer::new(); for raw_event in stream_iter { - let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap(); + let transformed_event = + SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap(); buffer.add_transformed_event(transformed_event); } - let output_bytes = buffer.into_bytes(); + let output_bytes = buffer.to_bytes(); let output = String::from_utf8_lossy(&output_bytes); println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):"); @@ -467,32 +539,71 @@ data: [DONE]"#; assert!(!output_bytes.is_empty(), "Should have output"); // Should have lifecycle events (injected by buffer) - assert!(output.contains("event: message_start"), "Should have message_start (injected)"); - assert!(output.contains("event: content_block_start"), "Should have content_block_start"); - assert!(output.contains("event: content_block_stop"), "Should have content_block_stop (injected)"); - assert!(output.contains("event: message_delta"), "Should have message_delta"); - assert!(output.contains("event: message_stop"), "Should have message_stop"); + assert!( + output.contains("event: message_start"), + "Should have message_start (injected)" + ); + assert!( + output.contains("event: content_block_start"), + "Should have content_block_start" + ); + assert!( + output.contains("event: content_block_stop"), + "Should have content_block_stop (injected)" + ); + assert!( + output.contains("event: message_delta"), + "Should have message_delta" + ); + assert!( + output.contains("event: message_stop"), + "Should have message_stop" + ); // Should have tool_use content block - assert!(output.contains("\"type\":\"tool_use\""), "Should have tool_use type"); - assert!(output.contains("\"name\":\"get_weather\""), "Should have correct function name"); - assert!(output.contains("\"id\":\"call_2Uzw0AEZQeOex2CP2TKjcLKc\""), "Should have correct tool call ID"); + assert!( + output.contains("\"type\":\"tool_use\""), + "Should have tool_use type" + ); + assert!( + output.contains("\"name\":\"get_weather\""), + "Should have correct function name" + ); + assert!( + output.contains("\"id\":\"call_2Uzw0AEZQeOex2CP2TKjcLKc\""), + "Should have correct tool call ID" + ); // Count input_json_delta events - should match the number of argument chunks let delta_count = output.matches("event: content_block_delta").count(); - assert!(delta_count >= 8, "Should have at least 8 input_json_delta events"); + assert!( + delta_count >= 8, + "Should have at least 8 input_json_delta events" + ); // Verify argument deltas are present - assert!(output.contains("\"type\":\"input_json_delta\""), "Should have input_json_delta type"); - assert!(output.contains("\"partial_json\":"), "Should have partial_json field"); + assert!( + output.contains("\"type\":\"input_json_delta\""), + "Should have input_json_delta type" + ); + assert!( + output.contains("\"partial_json\":"), + "Should have partial_json field" + ); // Verify the accumulated arguments contain the location assert!(output.contains("San"), "Arguments should contain 'San'"); - assert!(output.contains("Francisco"), "Arguments should contain 'Francisco'"); + assert!( + output.contains("Francisco"), + "Arguments should contain 'Francisco'" + ); assert!(output.contains("CA"), "Arguments should contain 'CA'"); // Verify stop reason is tool_use - assert!(output.contains("\"stop_reason\":\"tool_use\""), "Should have stop_reason as tool_use"); + assert!( + output.contains("\"stop_reason\":\"tool_use\""), + "Should have stop_reason as tool_use" + ); println!("\nVALIDATION SUMMARY:"); println!("{}", "-".repeat(80)); diff --git a/crates/hermesllm/src/apis/streaming_shapes/chat_completions_streaming_buffer.rs b/crates/hermesllm/src/apis/streaming_shapes/chat_completions_streaming_buffer.rs index 0243a5cd..78950717 100644 --- a/crates/hermesllm/src/apis/streaming_shapes/chat_completions_streaming_buffer.rs +++ b/crates/hermesllm/src/apis/streaming_shapes/chat_completions_streaming_buffer.rs @@ -6,6 +6,12 @@ pub struct OpenAIChatCompletionsStreamBuffer { buffered_events: Vec, } +impl Default for OpenAIChatCompletionsStreamBuffer { + fn default() -> Self { + Self::new() + } +} + impl OpenAIChatCompletionsStreamBuffer { pub fn new() -> Self { Self { @@ -26,7 +32,7 @@ impl SseStreamBufferTrait for OpenAIChatCompletionsStreamBuffer { self.buffered_events.push(event); } - fn into_bytes(&mut self) -> Vec { + fn to_bytes(&mut self) -> Vec { // No finalization needed for OpenAI Chat Completions // The [DONE] marker is already handled by the transformation layer let mut buffer = Vec::new(); diff --git a/crates/hermesllm/src/apis/streaming_shapes/mod.rs b/crates/hermesllm/src/apis/streaming_shapes/mod.rs index 4db3b094..b1118889 100644 --- a/crates/hermesllm/src/apis/streaming_shapes/mod.rs +++ b/crates/hermesllm/src/apis/streaming_shapes/mod.rs @@ -1,7 +1,7 @@ -pub mod sse; -pub mod sse_chunk_processor; pub mod amazon_bedrock_binary_frame; pub mod anthropic_streaming_buffer; pub mod chat_completions_streaming_buffer; pub mod passthrough_streaming_buffer; pub mod responses_api_streaming_buffer; +pub mod sse; +pub mod sse_chunk_processor; diff --git a/crates/hermesllm/src/apis/streaming_shapes/passthrough_streaming_buffer.rs b/crates/hermesllm/src/apis/streaming_shapes/passthrough_streaming_buffer.rs index 2ac2a688..53ed7620 100644 --- a/crates/hermesllm/src/apis/streaming_shapes/passthrough_streaming_buffer.rs +++ b/crates/hermesllm/src/apis/streaming_shapes/passthrough_streaming_buffer.rs @@ -6,6 +6,12 @@ pub struct PassthroughStreamBuffer { buffered_events: Vec, } +impl Default for PassthroughStreamBuffer { + fn default() -> Self { + Self::new() + } +} + impl PassthroughStreamBuffer { pub fn new() -> Self { Self { @@ -30,7 +36,7 @@ impl SseStreamBufferTrait for PassthroughStreamBuffer { self.buffered_events.push(event); } - fn into_bytes(&mut self) -> Vec { + fn to_bytes(&mut self) -> Vec { // No finalization needed for passthrough - just convert accumulated events to bytes let mut buffer = Vec::new(); for event in self.buffered_events.drain(..) { @@ -44,7 +50,7 @@ impl SseStreamBufferTrait for PassthroughStreamBuffer { #[cfg(test)] mod tests { use crate::apis::streaming_shapes::passthrough_streaming_buffer::PassthroughStreamBuffer; - use crate::apis::streaming_shapes::sse::{SseStreamIter, SseStreamBufferTrait}; + use crate::apis::streaming_shapes::sse::{SseStreamBufferTrait, SseStreamIter}; #[test] fn test_chat_completions_passthrough_buffer() { @@ -73,7 +79,7 @@ mod tests { buffer.add_transformed_event(event); } - let output_bytes = buffer.into_bytes(); + let output_bytes = buffer.to_bytes(); let output = String::from_utf8_lossy(&output_bytes); println!("\nTRANSFORMED OUTPUT (ChatCompletions - Passthrough):"); @@ -84,7 +90,11 @@ mod tests { assert!(!output_bytes.is_empty()); assert!(output.contains("chatcmpl-123")); assert!(output.contains("[DONE]")); - assert_eq!(raw_input.trim(), output.trim(), "Passthrough should preserve input"); + assert_eq!( + raw_input.trim(), + output.trim(), + "Passthrough should preserve input" + ); println!("\nVALIDATION SUMMARY:"); println!("{}", "-".repeat(80)); diff --git a/crates/hermesllm/src/apis/streaming_shapes/responses_api_streaming_buffer.rs b/crates/hermesllm/src/apis/streaming_shapes/responses_api_streaming_buffer.rs index ca8a9cfd..2aeb34ac 100644 --- a/crates/hermesllm/src/apis/streaming_shapes/responses_api_streaming_buffer.rs +++ b/crates/hermesllm/src/apis/streaming_shapes/responses_api_streaming_buffer.rs @@ -1,10 +1,10 @@ -use std::collections::HashMap; -use log::debug; use crate::apis::openai_responses::{ - ResponsesAPIStreamEvent, ResponsesAPIResponse, OutputItem, OutputItemStatus, - ResponseStatus, TextConfig, TextFormat, Reasoning, + OutputItem, OutputItemStatus, Reasoning, ResponseStatus, ResponsesAPIResponse, + ResponsesAPIStreamEvent, TextConfig, TextFormat, }; use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait}; +use log::debug; +use std::collections::HashMap; /// Helper to convert ResponseAPIStreamEvent to SseEvent fn event_to_sse(event: ResponsesAPIStreamEvent) -> SseEvent { @@ -16,10 +16,17 @@ fn event_to_sse(event: ResponsesAPIStreamEvent) -> SseEvent { ResponsesAPIStreamEvent::ResponseOutputItemDone { .. } => "response.output_item.done", ResponsesAPIStreamEvent::ResponseOutputTextDelta { .. } => "response.output_text.delta", ResponsesAPIStreamEvent::ResponseOutputTextDone { .. } => "response.output_text.done", - ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { .. } => "response.function_call_arguments.delta", - ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDone { .. } => "response.function_call_arguments.done", + ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { .. } => { + "response.function_call_arguments.delta" + } + ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDone { .. } => { + "response.function_call_arguments.done" + } unknown => { - debug!("Unknown ResponsesAPIStreamEvent type encountered: {:?}", unknown); + debug!( + "Unknown ResponsesAPIStreamEvent type encountered: {:?}", + unknown + ); "unknown" } }; @@ -85,6 +92,12 @@ pub struct ResponsesAPIStreamBuffer { buffered_events: Vec, } +impl Default for ResponsesAPIStreamBuffer { + fn default() -> Self { + Self::new() + } +} + impl ResponsesAPIStreamBuffer { pub fn new() -> Self { Self { @@ -112,7 +125,11 @@ impl ResponsesAPIStreamBuffer { } fn generate_item_id(prefix: &str) -> String { - format!("{}_{}", prefix, uuid::Uuid::new_v4().to_string().replace("-", "")) + format!( + "{}_{}", + prefix, + uuid::Uuid::new_v4().to_string().replace("-", "") + ) } fn get_or_create_item_id(&mut self, output_index: i32, prefix: &str) -> String { @@ -160,7 +177,13 @@ impl ResponsesAPIStreamBuffer { } /// Create output_item.added event for tool call - fn create_tool_call_added_event(&mut self, output_index: i32, item_id: &str, call_id: &str, name: &str) -> SseEvent { + fn create_tool_call_added_event( + &mut self, + output_index: i32, + item_id: &str, + call_id: &str, + name: &str, + ) -> SseEvent { let event = ResponsesAPIStreamEvent::ResponseOutputItemAdded { output_index, item: OutputItem::FunctionCall { @@ -237,9 +260,15 @@ impl ResponsesAPIStreamBuffer { // Emit done events for all accumulated content // Text content done events - let text_items: Vec<_> = self.text_content.iter().map(|(id, content)| (id.clone(), content.clone())).collect(); + let text_items: Vec<_> = self + .text_content + .iter() + .map(|(id, content)| (id.clone(), content.clone())) + .collect(); for (item_id, content) in text_items { - let output_index = self.output_items_added.iter() + let output_index = self + .output_items_added + .iter() .find(|(_, id)| **id == item_id) .map(|(idx, _)| *idx) .unwrap_or(0); @@ -270,9 +299,15 @@ impl ResponsesAPIStreamBuffer { } // Function call done events - let func_items: Vec<_> = self.function_arguments.iter().map(|(id, args)| (id.clone(), args.clone())).collect(); + let func_items: Vec<_> = self + .function_arguments + .iter() + .map(|(id, args)| (id.clone(), args.clone())) + .collect(); for (item_id, arguments) in func_items { - let output_index = self.output_items_added.iter() + let output_index = self + .output_items_added + .iter() .find(|(_, id)| **id == item_id) .map(|(idx, _)| *idx) .unwrap_or(0); @@ -286,9 +321,16 @@ impl ResponsesAPIStreamBuffer { }; events.push(event_to_sse(args_done_event)); - let (call_id, name) = self.tool_call_metadata.get(&output_index) + let (call_id, name) = self + .tool_call_metadata + .get(&output_index) .cloned() - .unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string())); + .unwrap_or_else(|| { + ( + format!("call_{}", uuid::Uuid::new_v4()), + "unknown".to_string(), + ) + }); let seq2 = self.next_sequence_number(); let item_done_event = ResponsesAPIStreamEvent::ResponseOutputItemDone { @@ -315,9 +357,16 @@ impl ResponsesAPIStreamBuffer { if let Some(item_id) = self.output_items_added.get(&output_index) { // Check if this is a function call if let Some(arguments) = self.function_arguments.get(item_id) { - let (call_id, name) = self.tool_call_metadata.get(&output_index) + let (call_id, name) = self + .tool_call_metadata + .get(&output_index) .cloned() - .unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string())); + .unwrap_or_else(|| { + ( + format!("call_{}", uuid::Uuid::new_v4()), + "unknown".to_string(), + ) + }); output_items.push(OutputItem::FunctionCall { id: item_id.clone(), @@ -397,9 +446,9 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer { let mut events = Vec::new(); // Capture upstream metadata from ResponseCreated or ResponseInProgress if present - match stream_event { - ResponsesAPIStreamEvent::ResponseCreated { response, .. } | - ResponsesAPIStreamEvent::ResponseInProgress { response, .. } => { + match stream_event.as_ref() { + ResponsesAPIStreamEvent::ResponseCreated { response, .. } + | ResponsesAPIStreamEvent::ResponseInProgress { response, .. } => { if self.upstream_response_metadata.is_none() { // Store the full upstream response as our metadata template self.upstream_response_metadata = Some(response.clone()); @@ -418,11 +467,16 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer { if !self.created_emitted { // Initialize metadata from first event if needed if self.response_id.is_none() { - self.response_id = Some(format!("resp_{}", uuid::Uuid::new_v4().to_string().replace("-", ""))); - self.created_at = Some(std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap() - .as_secs() as i64); + self.response_id = Some(format!( + "resp_{}", + uuid::Uuid::new_v4().to_string().replace("-", "") + )); + self.created_at = Some( + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() as i64, + ); self.model = Some("unknown".to_string()); // Will be set by caller if available } @@ -436,58 +490,95 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer { } // Process the delta event - match stream_event { - ResponsesAPIStreamEvent::ResponseOutputTextDelta { output_index, delta, .. } => { + match stream_event.as_ref() { + ResponsesAPIStreamEvent::ResponseOutputTextDelta { + output_index, + delta, + .. + } => { let item_id = self.get_or_create_item_id(*output_index, "msg"); // Emit output_item.added if this is the first time we see this output index if !self.output_items_added.contains_key(output_index) { - self.output_items_added.insert(*output_index, item_id.clone()); + self.output_items_added + .insert(*output_index, item_id.clone()); events.push(self.create_output_item_added_event(*output_index, &item_id)); } // Accumulate text content - self.text_content.entry(item_id.clone()) + self.text_content + .entry(item_id.clone()) .and_modify(|content| content.push_str(delta)) .or_insert_with(|| delta.clone()); // Emit text delta with filled-in item_id and sequence_number - let mut delta_event = stream_event.clone(); - if let ResponsesAPIStreamEvent::ResponseOutputTextDelta { item_id: ref mut id, sequence_number: ref mut seq, .. } = delta_event { + let mut delta_event = stream_event.as_ref().clone(); + if let ResponsesAPIStreamEvent::ResponseOutputTextDelta { + item_id: ref mut id, + sequence_number: ref mut seq, + .. + } = &mut delta_event + { *id = item_id; *seq = self.next_sequence_number(); } events.push(event_to_sse(delta_event)); } - ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { output_index, delta, call_id, name, .. } => { + ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { + output_index, + delta, + call_id, + name, + .. + } => { let item_id = self.get_or_create_item_id(*output_index, "fc"); // Store metadata if provided (from initial tool call event) if let (Some(cid), Some(n)) = (call_id, name) { - self.tool_call_metadata.insert(*output_index, (cid.clone(), n.clone())); + self.tool_call_metadata + .insert(*output_index, (cid.clone(), n.clone())); } // Emit output_item.added if this is the first time we see this tool call if !self.output_items_added.contains_key(output_index) { - self.output_items_added.insert(*output_index, item_id.clone()); + self.output_items_added + .insert(*output_index, item_id.clone()); // For tool calls, we need call_id and name from metadata // These should now be populated from the event itself - let (call_id, name) = self.tool_call_metadata.get(output_index) + let (call_id, name) = self + .tool_call_metadata + .get(output_index) .cloned() - .unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string())); + .unwrap_or_else(|| { + ( + format!("call_{}", uuid::Uuid::new_v4()), + "unknown".to_string(), + ) + }); - events.push(self.create_tool_call_added_event(*output_index, &item_id, &call_id, &name)); + events.push(self.create_tool_call_added_event( + *output_index, + &item_id, + &call_id, + &name, + )); } // Accumulate function arguments - self.function_arguments.entry(item_id.clone()) + self.function_arguments + .entry(item_id.clone()) .and_modify(|args| args.push_str(delta)) .or_insert_with(|| delta.clone()); // Emit function call arguments delta with filled-in item_id and sequence_number - let mut delta_event = stream_event.clone(); - if let ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { item_id: ref mut id, sequence_number: ref mut seq, .. } = delta_event { + let mut delta_event = stream_event.as_ref().clone(); + if let ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { + item_id: ref mut id, + sequence_number: ref mut seq, + .. + } = &mut delta_event + { *id = item_id; *seq = self.next_sequence_number(); } @@ -495,7 +586,7 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer { } _ => { // For other event types, just pass through with sequence number - let other_event = stream_event.clone(); + let other_event = stream_event.as_ref().clone(); // TODO: Add sequence number to other event types if needed events.push(event_to_sse(other_event)); } @@ -505,8 +596,7 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer { self.buffered_events.extend(events); } - - fn into_bytes(&mut self) -> Vec { + fn to_bytes(&mut self) -> Vec { // For Responses API, we need special handling: // - Most events are already in buffered_events from add_transformed_event // - We should NOT finalize here - finalization happens when we detect [DONE] or end of stream @@ -525,9 +615,9 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer { #[cfg(test)] mod tests { use super::*; - use crate::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs}; use crate::apis::openai::OpenAIApi; use crate::apis::streaming_shapes::sse::SseStreamIter; + use crate::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs}; #[test] fn test_chat_completions_to_responses_api_transformation() { @@ -557,11 +647,12 @@ mod tests { for raw_event in stream_iter { // Transform the event using the client/upstream APIs - let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap(); + let transformed_event = + SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap(); buffer.add_transformed_event(transformed_event); } - let output_bytes = buffer.into_bytes(); + let output_bytes = buffer.to_bytes(); let output = String::from_utf8_lossy(&output_bytes); println!("\nTRANSFORMED OUTPUT (ResponsesAPI):"); @@ -570,13 +661,34 @@ mod tests { // Assertions assert!(!output_bytes.is_empty(), "Should have output"); - assert!(output.contains("response.created"), "Should have response.created"); - assert!(output.contains("response.in_progress"), "Should have response.in_progress"); - assert!(output.contains("response.output_item.added"), "Should have output_item.added"); - assert!(output.contains("response.output_text.delta"), "Should have text deltas"); - assert!(output.contains("response.output_text.done"), "Should have text.done"); - assert!(output.contains("response.output_item.done"), "Should have output_item.done"); - assert!(output.contains("response.completed"), "Should have response.completed"); + assert!( + output.contains("response.created"), + "Should have response.created" + ); + assert!( + output.contains("response.in_progress"), + "Should have response.in_progress" + ); + assert!( + output.contains("response.output_item.added"), + "Should have output_item.added" + ); + assert!( + output.contains("response.output_text.delta"), + "Should have text deltas" + ); + assert!( + output.contains("response.output_text.done"), + "Should have text.done" + ); + assert!( + output.contains("response.output_item.done"), + "Should have output_item.done" + ); + assert!( + output.contains("response.completed"), + "Should have response.completed" + ); println!("\nVALIDATION SUMMARY:"); println!("{}", "-".repeat(80)); @@ -616,7 +728,7 @@ mod tests { buffer.add_transformed_event(transformed); } - let output_bytes = buffer.into_bytes(); + let output_bytes = buffer.to_bytes(); let output = String::from_utf8_lossy(&output_bytes); println!("\nTRANSFORMED OUTPUT (ResponsesAPI):"); @@ -624,24 +736,55 @@ mod tests { println!("{}", output); // Assertions - assert!(output.contains("response.created"), "Should have response.created"); - assert!(output.contains("response.in_progress"), "Should have response.in_progress"); - assert!(output.contains("response.output_item.added"), "Should have output_item.added"); - assert!(output.contains("\"type\":\"function_call\""), "Should be function_call type"); - assert!(output.contains("\"name\":\"get_weather\""), "Should have function name"); - assert!(output.contains("\"call_id\":\"call_mD5ggLKk3SMKGPFqFdcpKg6q\""), "Should have correct call_id"); + assert!( + output.contains("response.created"), + "Should have response.created" + ); + assert!( + output.contains("response.in_progress"), + "Should have response.in_progress" + ); + assert!( + output.contains("response.output_item.added"), + "Should have output_item.added" + ); + assert!( + output.contains("\"type\":\"function_call\""), + "Should be function_call type" + ); + assert!( + output.contains("\"name\":\"get_weather\""), + "Should have function name" + ); + assert!( + output.contains("\"call_id\":\"call_mD5ggLKk3SMKGPFqFdcpKg6q\""), + "Should have correct call_id" + ); - let delta_count = output.matches("event: response.function_call_arguments.delta").count(); + let delta_count = output + .matches("event: response.function_call_arguments.delta") + .count(); assert_eq!(delta_count, 4, "Should have 4 delta events"); - assert!(!output.contains("response.function_call_arguments.done"), "Should NOT have arguments.done"); - assert!(!output.contains("response.output_item.done"), "Should NOT have output_item.done"); - assert!(!output.contains("response.completed"), "Should NOT have response.completed"); + assert!( + !output.contains("response.function_call_arguments.done"), + "Should NOT have arguments.done" + ); + assert!( + !output.contains("response.output_item.done"), + "Should NOT have output_item.done" + ); + assert!( + !output.contains("response.completed"), + "Should NOT have response.completed" + ); println!("\nVALIDATION SUMMARY:"); println!("{}", "-".repeat(80)); println!("✓ Lifecycle events: response.created, response.in_progress"); - println!("✓ Function call metadata: name='get_weather', call_id='call_mD5ggLKk3SMKGPFqFdcpKg6q'"); + println!( + "✓ Function call metadata: name='get_weather', call_id='call_mD5ggLKk3SMKGPFqFdcpKg6q'" + ); println!("✓ Incremental deltas: 4 events (1 initial + 3 argument chunks)"); println!("✓ NO completion events (partial stream, no [DONE])"); println!("✓ Arguments accumulated: '{{\"location\":\"'\n"); diff --git a/crates/hermesllm/src/apis/streaming_shapes/sse.rs b/crates/hermesllm/src/apis/streaming_shapes/sse.rs index 05f0b296..f0d31e0a 100644 --- a/crates/hermesllm/src/apis/streaming_shapes/sse.rs +++ b/crates/hermesllm/src/apis/streaming_shapes/sse.rs @@ -1,9 +1,9 @@ -use crate::providers::streaming_response::ProviderStreamResponse; -use crate::providers::streaming_response::ProviderStreamResponseType; -use crate::apis::streaming_shapes::chat_completions_streaming_buffer::OpenAIChatCompletionsStreamBuffer; use crate::apis::streaming_shapes::anthropic_streaming_buffer::AnthropicMessagesStreamBuffer; +use crate::apis::streaming_shapes::chat_completions_streaming_buffer::OpenAIChatCompletionsStreamBuffer; use crate::apis::streaming_shapes::passthrough_streaming_buffer::PassthroughStreamBuffer; use crate::apis::streaming_shapes::responses_api_streaming_buffer::ResponsesAPIStreamBuffer; +use crate::providers::streaming_response::ProviderStreamResponse; +use crate::providers::streaming_response::ProviderStreamResponseType; use serde::{Deserialize, Serialize}; use std::error::Error; use std::fmt; @@ -37,7 +37,7 @@ pub trait SseStreamBufferTrait: Send + Sync { /// /// # Returns /// Bytes ready for wire transmission (may be empty if no events were accumulated) - fn into_bytes(&mut self) -> Vec; + fn to_bytes(&mut self) -> Vec; } /// Unified SSE Stream Buffer enum that provides a zero-cost abstraction @@ -45,7 +45,7 @@ pub enum SseStreamBuffer { Passthrough(PassthroughStreamBuffer), OpenAIChatCompletions(OpenAIChatCompletionsStreamBuffer), AnthropicMessages(AnthropicMessagesStreamBuffer), - OpenAIResponses(ResponsesAPIStreamBuffer), + OpenAIResponses(Box), } impl SseStreamBufferTrait for SseStreamBuffer { @@ -58,12 +58,12 @@ impl SseStreamBufferTrait for SseStreamBuffer { } } - fn into_bytes(&mut self) -> Vec { + fn to_bytes(&mut self) -> Vec { match self { - Self::Passthrough(buffer) => buffer.into_bytes(), - Self::OpenAIChatCompletions(buffer) => buffer.into_bytes(), - Self::AnthropicMessages(buffer) => buffer.into_bytes(), - Self::OpenAIResponses(buffer) => buffer.into_bytes(), + Self::Passthrough(buffer) => buffer.to_bytes(), + Self::OpenAIChatCompletions(buffer) => buffer.to_bytes(), + Self::AnthropicMessages(buffer) => buffer.to_bytes(), + Self::OpenAIResponses(buffer) => buffer.to_bytes(), } } } @@ -99,7 +99,7 @@ impl SseEvent { let sse_string: String = response.clone().into(); SseEvent { - data: None, // Data is embedded in sse_transformed_lines + data: None, // Data is embedded in sse_transformed_lines event: None, // Event type is embedded in sse_transformed_lines raw_line: sse_string.clone(), sse_transformed_lines: sse_string, @@ -149,10 +149,8 @@ impl FromStr for SseEvent { }); } - if trimmed_line.starts_with("data: ") { - let data: String = trimmed_line[6..].to_string(); // Remove "data: " prefix - // Allow empty data content after "data: " prefix - // This handles cases like "data: " followed by newline + if let Some(stripped) = trimmed_line.strip_prefix("data: ") { + let data: String = stripped.to_string(); if data.trim().is_empty() { return Err(SseParseError { message: "Empty data field after 'data: ' prefix".to_string(), @@ -166,8 +164,8 @@ impl FromStr for SseEvent { sse_transformed_lines: line.to_string(), provider_stream_response: None, }) - } else if trimmed_line.starts_with("event: ") { - let event_type = trimmed_line[7..].to_string(); + } else if let Some(stripped) = trimmed_line.strip_prefix("event: ") { + let event_type = stripped.to_string(); if event_type.is_empty() { return Err(SseParseError { message: "Empty event field is not a valid SSE event".to_string(), @@ -183,7 +181,10 @@ impl FromStr for SseEvent { }) } else { Err(SseParseError { - message: format!("Line does not start with 'data: ' or 'event: ': {}", trimmed_line), + message: format!( + "Line does not start with 'data: ' or 'event: ': {}", + trimmed_line + ), }) } } @@ -196,16 +197,16 @@ impl fmt::Display for SseEvent { } // Into implementation to convert SseEvent to bytes for response buffer -impl Into> for SseEvent { - fn into(self) -> Vec { +impl From for Vec { + fn from(val: SseEvent) -> Self { // For generated events (like ResponsesAPI), sse_transformed_lines already includes trailing \n\n // For parsed events (like passthrough), we need to add the \n\n separator - if self.sse_transformed_lines.ends_with("\n\n") { + if val.sse_transformed_lines.ends_with("\n\n") { // Already properly formatted with trailing newlines - self.sse_transformed_lines.into_bytes() + val.sse_transformed_lines.into_bytes() } else { // Add SSE event separator - format!("{}\n\n", self.sse_transformed_lines).into_bytes() + format!("{}\n\n", val.sse_transformed_lines).into_bytes() } } } diff --git a/crates/hermesllm/src/apis/streaming_shapes/sse_chunk_processor.rs b/crates/hermesllm/src/apis/streaming_shapes/sse_chunk_processor.rs index c7d25527..64814d65 100644 --- a/crates/hermesllm/src/apis/streaming_shapes/sse_chunk_processor.rs +++ b/crates/hermesllm/src/apis/streaming_shapes/sse_chunk_processor.rs @@ -10,6 +10,12 @@ pub struct SseChunkProcessor { incomplete_event_buffer: Vec, } +impl Default for SseChunkProcessor { + fn default() -> Self { + Self::new() + } +} + impl SseChunkProcessor { pub fn new() -> Self { Self { @@ -93,8 +99,8 @@ impl SseChunkProcessor { #[cfg(test)] mod tests { use super::*; - use crate::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs}; use crate::apis::openai::OpenAIApi; + use crate::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs}; #[test] fn test_complete_events_process_immediately() { @@ -104,7 +110,9 @@ mod tests { let chunk1 = b"data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n"; - let events = processor.process_chunk(chunk1, &client_api, &upstream_api).unwrap(); + let events = processor + .process_chunk(chunk1, &client_api, &upstream_api) + .unwrap(); assert_eq!(events.len(), 1); assert!(!processor.has_buffered_data()); @@ -119,18 +127,28 @@ mod tests { // First chunk with incomplete JSON let chunk1 = b"data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chu"; - let events1 = processor.process_chunk(chunk1, &client_api, &upstream_api).unwrap(); + let events1 = processor + .process_chunk(chunk1, &client_api, &upstream_api) + .unwrap(); assert_eq!(events1.len(), 0, "Incomplete event should not be processed"); - assert!(processor.has_buffered_data(), "Incomplete data should be buffered"); + assert!( + processor.has_buffered_data(), + "Incomplete data should be buffered" + ); // Second chunk completes the JSON let chunk2 = b"nk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n"; - let events2 = processor.process_chunk(chunk2, &client_api, &upstream_api).unwrap(); + let events2 = processor + .process_chunk(chunk2, &client_api, &upstream_api) + .unwrap(); assert_eq!(events2.len(), 1, "Complete event should be processed"); - assert!(!processor.has_buffered_data(), "Buffer should be cleared after completion"); + assert!( + !processor.has_buffered_data(), + "Buffer should be cleared after completion" + ); } #[test] @@ -142,10 +160,15 @@ mod tests { // Chunk with 2 complete events and 1 incomplete let chunk = b"data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"A\"},\"finish_reason\":null}]}\n\ndata: {\"id\":\"chatcmpl-124\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"B\"},\"finish_reason\":null}]}\n\ndata: {\"id\":\"chatcmpl-125\",\"object\":\"chat.completion.chu"; - let events = processor.process_chunk(chunk, &client_api, &upstream_api).unwrap(); + let events = processor + .process_chunk(chunk, &client_api, &upstream_api) + .unwrap(); assert_eq!(events.len(), 2, "Two complete events should be processed"); - assert!(processor.has_buffered_data(), "Incomplete third event should be buffered"); + assert!( + processor.has_buffered_data(), + "Incomplete third event should be buffered" + ); } #[test] @@ -171,11 +194,23 @@ data: {"type":"content_block_stop","index":0} Ok(events) => { println!("Successfully processed {} events", events.len()); for (i, event) in events.iter().enumerate() { - println!("Event {}: event={:?}, has_data={}", i, event.event, event.data.is_some()); + println!( + "Event {}: event={:?}, has_data={}", + i, + event.event, + event.data.is_some() + ); } // Should successfully process both events (signature_delta + content_block_stop) - assert!(events.len() >= 2, "Should process at least 2 complete events (signature_delta + stop), got {}", events.len()); - assert!(!processor.has_buffered_data(), "Complete events should not be buffered"); + assert!( + events.len() >= 2, + "Should process at least 2 complete events (signature_delta + stop), got {}", + events.len() + ); + assert!( + !processor.has_buffered_data(), + "Complete events should not be buffered" + ); } Err(e) => { panic!("Failed to process signature_delta chunk - this means SignatureDelta is not properly handled: {}", e); @@ -194,12 +229,21 @@ data: {"type":"content_block_stop","index":0} // Second event is valid and should be processed let chunk = b"data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"unsupported_field_causing_validation_error\":true},\"finish_reason\":null}]}\n\ndata: {\"id\":\"chatcmpl-124\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n"; - let events = processor.process_chunk(chunk, &client_api, &upstream_api).unwrap(); + let events = processor + .process_chunk(chunk, &client_api, &upstream_api) + .unwrap(); // Should skip the invalid event and process the valid one // (If we were buffering all errors, we'd get 0 events and have buffered data) - assert!(events.len() >= 1, "Should process at least the valid event, got {} events", events.len()); - assert!(!processor.has_buffered_data(), "Invalid (non-incomplete) events should not be buffered"); + assert!( + !events.is_empty(), + "Should process at least the valid event, got {} events", + events.len() + ); + assert!( + !processor.has_buffered_data(), + "Invalid (non-incomplete) events should not be buffered" + ); } #[test] @@ -227,14 +271,27 @@ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text match result { Ok(events) => { - println!("Processed {} events (unsupported event should be skipped)", events.len()); + println!( + "Processed {} events (unsupported event should be skipped)", + events.len() + ); // Should process the 2 valid text_delta events and skip the unsupported one // We expect at least 2 events (the valid ones), unsupported should be skipped - assert!(events.len() >= 2, "Should process at least 2 valid events, got {}", events.len()); - assert!(!processor.has_buffered_data(), "Unsupported events should be skipped, not buffered"); + assert!( + events.len() >= 2, + "Should process at least 2 valid events, got {}", + events.len() + ); + assert!( + !processor.has_buffered_data(), + "Unsupported events should be skipped, not buffered" + ); } Err(e) => { - panic!("Should not fail on unsupported delta type, should skip it: {}", e); + panic!( + "Should not fail on unsupported delta type, should skip it: {}", + e + ); } } } diff --git a/crates/hermesllm/src/clients/endpoints.rs b/crates/hermesllm/src/clients/endpoints.rs index 5a923329..eff96cc5 100644 --- a/crates/hermesllm/src/clients/endpoints.rs +++ b/crates/hermesllm/src/clients/endpoints.rs @@ -135,7 +135,10 @@ impl SupportedAPIsFromClient { ProviderId::AzureOpenAI => { if request_path.starts_with("/v1/") { let suffix = endpoint_suffix.trim_start_matches('/'); - build_endpoint("/openai/deployments", &format!("/{}/{}?api-version=2025-01-01-preview", model_id, suffix)) + build_endpoint( + "/openai/deployments", + &format!("/{}/{}?api-version=2025-01-01-preview", model_id, suffix), + ) } else { build_endpoint("/v1", endpoint_suffix) } @@ -163,19 +166,21 @@ impl SupportedAPIsFromClient { }; match self { - SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages) => match provider_id { - ProviderId::Anthropic => build_endpoint("/v1", "/messages"), - ProviderId::AmazonBedrock => { - if request_path.starts_with("/v1/") && !is_streaming { - build_endpoint("", &format!("/model/{}/converse", model_id)) - } else if request_path.starts_with("/v1/") && is_streaming { - build_endpoint("", &format!("/model/{}/converse-stream", model_id)) - } else { - build_endpoint("/v1", "/chat/completions") + SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages) => { + match provider_id { + ProviderId::Anthropic => build_endpoint("/v1", "/messages"), + ProviderId::AmazonBedrock => { + if request_path.starts_with("/v1/") && !is_streaming { + build_endpoint("", &format!("/model/{}/converse", model_id)) + } else if request_path.starts_with("/v1/") && is_streaming { + build_endpoint("", &format!("/model/{}/converse-stream", model_id)) + } else { + build_endpoint("/v1", "/chat/completions") + } } + _ => build_endpoint("/v1", "/chat/completions"), } - _ => build_endpoint("/v1", "/chat/completions"), - }, + } SupportedAPIsFromClient::OpenAIResponsesAPI(_) => { // For Responses API, check if provider supports it, otherwise translate to chat/completions match provider_id { @@ -193,7 +198,6 @@ impl SupportedAPIsFromClient { } } - impl SupportedUpstreamAPIs { /// Create a SupportedUpstreamApi from an endpoint path pub fn from_endpoint(endpoint: &str) -> Option { @@ -216,17 +220,17 @@ impl SupportedUpstreamAPIs { return Some(SupportedUpstreamAPIs::AmazonBedrockConverse(bedrock_api)) } AmazonBedrockApi::ConverseStream => { - return Some(SupportedUpstreamAPIs::AmazonBedrockConverseStream(bedrock_api)) + return Some(SupportedUpstreamAPIs::AmazonBedrockConverseStream( + bedrock_api, + )) } } } None } - } - /// Get all supported endpoint paths pub fn supported_endpoints() -> Vec<&'static str> { let mut endpoints = Vec::new(); @@ -269,9 +273,9 @@ mod tests { assert!(SupportedAPIsFromClient::from_endpoint("/v1/messages").is_some()); // Unsupported endpoints - assert!(!SupportedAPIsFromClient::from_endpoint("/v1/unknown").is_some()); - assert!(!SupportedAPIsFromClient::from_endpoint("/v2/chat").is_some()); - assert!(!SupportedAPIsFromClient::from_endpoint("").is_some()); + assert!(SupportedAPIsFromClient::from_endpoint("/v1/unknown").is_none()); + assert!(SupportedAPIsFromClient::from_endpoint("/v2/chat").is_none()); + assert!(SupportedAPIsFromClient::from_endpoint("").is_none()); } #[test] diff --git a/crates/hermesllm/src/lib.rs b/crates/hermesllm/src/lib.rs index 918fd4e9..3f8324e9 100644 --- a/crates/hermesllm/src/lib.rs +++ b/crates/hermesllm/src/lib.rs @@ -12,11 +12,9 @@ pub use aws_smithy_eventstream::frame::DecodedFrame; pub use providers::id::ProviderId; pub use providers::request::{ProviderRequest, ProviderRequestError, ProviderRequestType}; pub use providers::response::{ - ProviderResponse, ProviderResponseType, TokenUsage, ProviderResponseError -}; -pub use providers::streaming_response::{ - ProviderStreamResponse, ProviderStreamResponseType + ProviderResponse, ProviderResponseError, ProviderResponseType, TokenUsage, }; +pub use providers::streaming_response::{ProviderStreamResponse, ProviderStreamResponseType}; //TODO: Refactor such that commons doesn't depend on Hermes. For now this will clean up strings pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions"; @@ -87,11 +85,17 @@ mod tests { let done_event = streaming_iter.next(); assert!(done_event.is_some(), "Should get [DONE] event"); let done_event = done_event.unwrap(); - assert!(done_event.is_done(), "[DONE] event should be marked as done"); + assert!( + done_event.is_done(), + "[DONE] event should be marked as done" + ); // After [DONE], iterator should return None let final_event = streaming_iter.next(); - assert!(final_event.is_none(), "Iterator should return None after [DONE]"); + assert!( + final_event.is_none(), + "Iterator should return None after [DONE]" + ); } /// Test AWS Event Stream decoding for Bedrock ConverseStream responses. @@ -130,7 +134,7 @@ mod tests { let mut content_chunks = Vec::new(); // Simulate chunked network arrivals - process as data comes in - let chunk_sizes = vec![50, 100, 75, 200, 150, 300, 500, 1000]; + let chunk_sizes = [50, 100, 75, 200, 150, 300, 500, 1000]; let mut offset = 0; let mut chunk_num = 0; diff --git a/crates/hermesllm/src/providers/id.rs b/crates/hermesllm/src/providers/id.rs index 344a795f..afaabea6 100644 --- a/crates/hermesllm/src/providers/id.rs +++ b/crates/hermesllm/src/providers/id.rs @@ -59,10 +59,9 @@ impl ProviderId { (ProviderId::Anthropic, SupportedAPIsFromClient::AnthropicMessagesAPI(_)) => { SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) } - ( - ProviderId::Anthropic, - SupportedAPIsFromClient::OpenAIChatCompletions(_), - ) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), + (ProviderId::Anthropic, SupportedAPIsFromClient::OpenAIChatCompletions(_)) => { + SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions) + } // Anthropic doesn't support Responses API, fall back to chat completions (ProviderId::Anthropic, SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => { diff --git a/crates/hermesllm/src/providers/request.rs b/crates/hermesllm/src/providers/request.rs index 7cd951c3..d1d85888 100644 --- a/crates/hermesllm/src/providers/request.rs +++ b/crates/hermesllm/src/providers/request.rs @@ -10,6 +10,7 @@ use serde_json::Value; use std::collections::HashMap; use std::error::Error; use std::fmt; +#[allow(clippy::large_enum_variant)] #[derive(Clone, Debug)] pub enum ProviderRequestType { ChatCompletionsRequest(ChatCompletionsRequest), @@ -197,7 +198,9 @@ impl ProviderRequest for ProviderRequestType { impl TryFrom<(&[u8], &SupportedAPIsFromClient)> for ProviderRequestType { type Error = std::io::Error; - fn try_from((bytes, client_api): (&[u8], &SupportedAPIsFromClient)) -> Result { + fn try_from( + (bytes, client_api): (&[u8], &SupportedAPIsFromClient), + ) -> Result { // Use SupportedApi to determine the appropriate request type match client_api { SupportedAPIsFromClient::OpenAIChatCompletions(_) => { @@ -882,7 +885,7 @@ mod tests { ProviderRequestType::BedrockConverse(bedrock_req) => { assert_eq!(bedrock_req.model_id, "gpt-4o"); // Bedrock receives the converted request through ChatCompletions - assert!(!bedrock_req.messages.is_none()); + assert!(bedrock_req.messages.is_some()); } _ => panic!("Expected BedrockConverse variant"), } @@ -913,7 +916,9 @@ mod tests { assert!(result.is_err()); let err = result.unwrap_err(); - assert!(err.message.contains("ResponsesAPI can only be used as a client API")); + assert!(err + .message + .contains("ResponsesAPI can only be used as a client API")); } #[test] @@ -953,7 +958,9 @@ mod tests { assert!(result.is_err()); let err = result.unwrap_err(); - assert!(err.message.contains("ResponsesAPI can only be used as a client API")); + assert!(err + .message + .contains("ResponsesAPI can only be used as a client API")); } #[test] @@ -1023,9 +1030,7 @@ mod tests { role: MessagesRole::User, content: MessagesMessageContent::Single("Hello!".to_string()), }], - system: Some(MessagesSystemPrompt::Single( - "You are helpful".to_string(), - )), + system: Some(MessagesSystemPrompt::Single("You are helpful".to_string())), max_tokens: 100, container: None, mcp_servers: None, @@ -1046,14 +1051,8 @@ mod tests { // Should have system message + user message assert_eq!(messages.len(), 2); - assert_eq!( - messages[0].role, - crate::apis::openai::Role::System - ); - assert_eq!( - messages[1].role, - crate::apis::openai::Role::User - ); + assert_eq!(messages[0].role, crate::apis::openai::Role::System); + assert_eq!(messages[1].role, crate::apis::openai::Role::User); } #[test] @@ -1094,13 +1093,7 @@ mod tests { // Should have system message (instructions) + user message (input) assert_eq!(messages.len(), 2); - assert_eq!( - messages[0].role, - crate::apis::openai::Role::System - ); - assert_eq!( - messages[1].role, - crate::apis::openai::Role::User - ); + assert_eq!(messages[0].role, crate::apis::openai::Role::System); + assert_eq!(messages[1].role, crate::apis::openai::Role::User); } } diff --git a/crates/hermesllm/src/providers/response.rs b/crates/hermesllm/src/providers/response.rs index a2494c6d..5f46f97b 100644 --- a/crates/hermesllm/src/providers/response.rs +++ b/crates/hermesllm/src/providers/response.rs @@ -1,7 +1,3 @@ -use serde::Serialize; -use std::convert::TryFrom; -use std::error::Error; -use std::fmt; use crate::apis::amazon_bedrock::ConverseResponse; use crate::apis::anthropic::MessagesResponse; use crate::apis::openai::ChatCompletionsResponse; @@ -9,14 +5,17 @@ use crate::apis::openai_responses::ResponsesAPIResponse; use crate::clients::endpoints::SupportedAPIsFromClient; use crate::clients::endpoints::SupportedUpstreamAPIs; use crate::providers::id::ProviderId; - +use serde::Serialize; +use std::convert::TryFrom; +use std::error::Error; +use std::fmt; #[derive(Serialize, Debug, Clone)] #[serde(untagged)] pub enum ProviderResponseType { ChatCompletionsResponse(ChatCompletionsResponse), MessagesResponse(MessagesResponse), - ResponsesAPIResponse(ResponsesAPIResponse), + ResponsesAPIResponse(Box), } /// Trait for token usage information @@ -42,7 +41,9 @@ impl ProviderResponse for ProviderResponseType { match self { ProviderResponseType::ChatCompletionsResponse(resp) => resp.usage(), ProviderResponseType::MessagesResponse(resp) => resp.usage(), - ProviderResponseType::ResponsesAPIResponse(resp) => resp.usage.as_ref().map(|u| u as &dyn TokenUsage), + ProviderResponseType::ResponsesAPIResponse(resp) => { + resp.usage.as_ref().map(|u| u as &dyn TokenUsage) + } } } @@ -50,11 +51,13 @@ impl ProviderResponse for ProviderResponseType { match self { ProviderResponseType::ChatCompletionsResponse(resp) => resp.extract_usage_counts(), ProviderResponseType::MessagesResponse(resp) => resp.extract_usage_counts(), - ProviderResponseType::ResponsesAPIResponse(resp) => { - resp.usage.as_ref().map(|u| { - (u.input_tokens as usize, u.output_tokens as usize, u.total_tokens as usize) - }) - } + ProviderResponseType::ResponsesAPIResponse(resp) => resp.usage.as_ref().map(|u| { + ( + u.input_tokens as usize, + u.output_tokens as usize, + u.total_tokens as usize, + ) + }), } } } @@ -156,40 +159,44 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClient, &ProviderId)> for ProviderRespons ) => { let resp: ResponsesAPIResponse = ResponsesAPIResponse::try_from(bytes) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; - Ok(ProviderResponseType::ResponsesAPIResponse(resp)) + Ok(ProviderResponseType::ResponsesAPIResponse(Box::new(resp))) } ( SupportedUpstreamAPIs::OpenAIChatCompletions(_), SupportedAPIsFromClient::OpenAIResponsesAPI(_), ) => { - let chat_completions_response: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + let chat_completions_response: ChatCompletionsResponse = + ChatCompletionsResponse::try_from(bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; // Transform to ResponsesAPI format using the transformer - let responses_resp: ResponsesAPIResponse = chat_completions_response.try_into().map_err(|e| { - std::io::Error::new( - std::io::ErrorKind::InvalidData, - format!("Transformation error: {}", e), - ) - })?; - Ok(ProviderResponseType::ResponsesAPIResponse(responses_resp)) + let responses_resp: ResponsesAPIResponse = + chat_completions_response.try_into().map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Transformation error: {}", e), + ) + })?; + Ok(ProviderResponseType::ResponsesAPIResponse(Box::new( + responses_resp, + ))) } ( SupportedUpstreamAPIs::AnthropicMessagesAPI(_), SupportedAPIsFromClient::OpenAIResponsesAPI(_), ) => { - //Chain transform: Anthropic Messages -> OpenAI ChatCompletions -> ResponsesAPI let anthropic_resp: MessagesResponse = serde_json::from_slice(bytes) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; // Transform to ChatCompletions format using the transformer - let chat_resp: ChatCompletionsResponse = anthropic_resp.try_into().map_err(|e| { - std::io::Error::new( - std::io::ErrorKind::InvalidData, - format!("Transformation error: {}", e), - ) - })?; + let chat_resp: ChatCompletionsResponse = + anthropic_resp.try_into().map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Transformation error: {}", e), + ) + })?; let response_api: ResponsesAPIResponse = chat_resp.try_into().map_err(|e| { std::io::Error::new( @@ -197,7 +204,9 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClient, &ProviderId)> for ProviderRespons format!("Transformation error: {}", e), ) })?; - Ok(ProviderResponseType::ResponsesAPIResponse(response_api)) + Ok(ProviderResponseType::ResponsesAPIResponse(Box::new( + response_api, + ))) } ( SupportedUpstreamAPIs::AmazonBedrockConverse(_), @@ -219,10 +228,15 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClient, &ProviderId)> for ProviderRespons let response_api: ResponsesAPIResponse = chat_resp.try_into().map_err(|e| { std::io::Error::new( std::io::ErrorKind::InvalidData, - format!("ChatCompletions to ResponsesAPI transformation error: {}", e), + format!( + "ChatCompletions to ResponsesAPI transformation error: {}", + e + ), ) })?; - Ok(ProviderResponseType::ResponsesAPIResponse(response_api)) + Ok(ProviderResponseType::ResponsesAPIResponse(Box::new( + response_api, + ))) } _ => Err(std::io::Error::new( std::io::ErrorKind::InvalidData, @@ -255,8 +269,8 @@ impl Error for ProviderResponseError { #[cfg(test)] mod tests { use super::*; - use crate::apis::openai::OpenAIApi; use crate::apis::anthropic::AnthropicApi; + use crate::apis::openai::OpenAIApi; use crate::clients::endpoints::SupportedAPIsFromClient; use crate::providers::id::ProviderId; use serde_json::json; diff --git a/crates/hermesllm/src/providers/streaming_response.rs b/crates/hermesllm/src/providers/streaming_response.rs index 55e52f3d..9fc83065 100644 --- a/crates/hermesllm/src/providers/streaming_response.rs +++ b/crates/hermesllm/src/providers/streaming_response.rs @@ -1,18 +1,17 @@ use serde::Serialize; use std::convert::TryFrom; +use crate::apis::amazon_bedrock::ConverseStreamEvent; +use crate::apis::anthropic::MessagesStreamEvent; use crate::apis::openai::ChatCompletionsStreamResponse; use crate::apis::openai_responses::ResponsesAPIStreamEvent; use crate::apis::streaming_shapes::sse::SseEvent; -use crate::apis::amazon_bedrock::ConverseStreamEvent; -use crate::apis::anthropic::MessagesStreamEvent; use crate::apis::streaming_shapes::sse::SseStreamBuffer; use crate::apis::streaming_shapes::{ - anthropic_streaming_buffer::AnthropicMessagesStreamBuffer, - chat_completions_streaming_buffer::OpenAIChatCompletionsStreamBuffer, - passthrough_streaming_buffer::PassthroughStreamBuffer, - responses_api_streaming_buffer::ResponsesAPIStreamBuffer, - }; + anthropic_streaming_buffer::AnthropicMessagesStreamBuffer, + chat_completions_streaming_buffer::OpenAIChatCompletionsStreamBuffer, + passthrough_streaming_buffer::PassthroughStreamBuffer, +}; use crate::clients::endpoints::SupportedAPIsFromClient; use crate::clients::endpoints::SupportedUpstreamAPIs; @@ -28,9 +27,18 @@ pub fn needs_buffering( ) -> bool { match (client_api, upstream_api) { // Same APIs - no buffering needed - (SupportedAPIsFromClient::OpenAIChatCompletions(_), SupportedUpstreamAPIs::OpenAIChatCompletions(_)) => false, - (SupportedAPIsFromClient::AnthropicMessagesAPI(_), SupportedUpstreamAPIs::AnthropicMessagesAPI(_)) => false, - (SupportedAPIsFromClient::OpenAIResponsesAPI(_), SupportedUpstreamAPIs::OpenAIResponsesAPI(_)) => false, + ( + SupportedAPIsFromClient::OpenAIChatCompletions(_), + SupportedUpstreamAPIs::OpenAIChatCompletions(_), + ) => false, + ( + SupportedAPIsFromClient::AnthropicMessagesAPI(_), + SupportedUpstreamAPIs::AnthropicMessagesAPI(_), + ) => false, + ( + SupportedAPIsFromClient::OpenAIResponsesAPI(_), + SupportedUpstreamAPIs::OpenAIResponsesAPI(_), + ) => false, // Different APIs - buffering needed _ => true, @@ -53,15 +61,12 @@ pub fn needs_buffering( /// // Flush to wire /// let bytes = buffer.into_bytes(); /// ``` -impl TryFrom<(&SupportedAPIsFromClient, &SupportedUpstreamAPIs)> - for SseStreamBuffer -{ +impl TryFrom<(&SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for SseStreamBuffer { type Error = Box; fn try_from( (client_api, upstream_api): (&SupportedAPIsFromClient, &SupportedUpstreamAPIs), ) -> Result { - // If APIs match, use passthrough - no buffering/transformation needed if !needs_buffering(client_api, upstream_api) { return Ok(SseStreamBuffer::Passthrough(PassthroughStreamBuffer::new())); @@ -69,14 +74,14 @@ impl TryFrom<(&SupportedAPIsFromClient, &SupportedUpstreamAPIs)> // APIs differ - use appropriate buffer for client API match client_api { - SupportedAPIsFromClient::OpenAIChatCompletions(_) => { - Ok(SseStreamBuffer::OpenAIChatCompletions(OpenAIChatCompletionsStreamBuffer::new())) - } - SupportedAPIsFromClient::AnthropicMessagesAPI(_) => { - Ok(SseStreamBuffer::AnthropicMessages(AnthropicMessagesStreamBuffer::new())) - } + SupportedAPIsFromClient::OpenAIChatCompletions(_) => Ok( + SseStreamBuffer::OpenAIChatCompletions(OpenAIChatCompletionsStreamBuffer::new()), + ), + SupportedAPIsFromClient::AnthropicMessagesAPI(_) => Ok( + SseStreamBuffer::AnthropicMessages(AnthropicMessagesStreamBuffer::new()), + ), SupportedAPIsFromClient::OpenAIResponsesAPI(_) => { - Ok(SseStreamBuffer::OpenAIResponses(ResponsesAPIStreamBuffer::new())) + Ok(SseStreamBuffer::OpenAIResponses(Box::default())) } } } @@ -88,11 +93,12 @@ impl TryFrom<(&SupportedAPIsFromClient, &SupportedUpstreamAPIs)> #[derive(Serialize, Debug, Clone)] #[serde(untagged)] +#[allow(clippy::large_enum_variant)] pub enum ProviderStreamResponseType { ChatCompletionsStreamResponse(ChatCompletionsStreamResponse), MessagesStreamEvent(MessagesStreamEvent), ConverseStreamEvent(ConverseStreamEvent), - ResponseAPIStreamEvent(ResponsesAPIStreamEvent) + ResponseAPIStreamEvent(Box), } pub trait ProviderStreamResponse: Send + Sync { @@ -145,12 +151,11 @@ impl ProviderStreamResponse for ProviderStreamResponseType { ProviderStreamResponseType::ResponseAPIStreamEvent(resp) => resp.event_type(), } } - } -impl Into for ProviderStreamResponseType { - fn into(self) -> String { - match self { +impl From for String { + fn from(val: ProviderStreamResponseType) -> String { + match val { ProviderStreamResponseType::MessagesStreamEvent(event) => { // Use the Into implementation for proper SSE formatting with event lines event.into() @@ -161,27 +166,36 @@ impl Into for ProviderStreamResponseType { } ProviderStreamResponseType::ResponseAPIStreamEvent(event) => { // Use the Into implementation for proper SSE formatting with event lines - event.into() + // Clone to work around Box ownership + let cloned = (*event).clone(); + cloned.into() } ProviderStreamResponseType::ChatCompletionsStreamResponse(_) => { // For OpenAI, use simple data line format - let json = serde_json::to_string(&self).unwrap_or_default(); + let json = serde_json::to_string(&val).unwrap_or_default(); format!("data: {}\n\n", json) } } } } - // Stream response transformation logic for client API compatibility -impl TryFrom<(&[u8], &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for ProviderStreamResponseType { +impl TryFrom<(&[u8], &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> + for ProviderStreamResponseType +{ type Error = Box; fn try_from( - (bytes, client_api, upstream_api): (&[u8], &SupportedAPIsFromClient, &SupportedUpstreamAPIs), + (bytes, client_api, upstream_api): ( + &[u8], + &SupportedAPIsFromClient, + &SupportedUpstreamAPIs, + ), ) -> Result { // Special case: Handle [DONE] marker for OpenAI -> Anthropic conversion - if bytes == b"[DONE]" && matches!(client_api, SupportedAPIsFromClient::AnthropicMessagesAPI(_)) { + if bytes == b"[DONE]" + && matches!(client_api, SupportedAPIsFromClient::AnthropicMessagesAPI(_)) + { return Ok(ProviderStreamResponseType::MessagesStreamEvent( crate::apis::anthropic::MessagesStreamEvent::MessageStop, )); @@ -214,9 +228,9 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for Prov ) => { let openai_resp: crate::apis::openai::ChatCompletionsStreamResponse = serde_json::from_slice(bytes)?; - let responses_resp = openai_resp.try_into()?; + let responses_resp: ResponsesAPIStreamEvent = openai_resp.try_into()?; Ok(ProviderStreamResponseType::ResponseAPIStreamEvent( - responses_resp, + Box::new(responses_resp), )) } @@ -267,10 +281,11 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for Prov // Chain: Bedrock -> ChatCompletions -> ResponsesAPI let bedrock_resp: crate::apis::amazon_bedrock::ConverseStreamEvent = serde_json::from_slice(bytes)?; - let chat_resp: crate::apis::openai::ChatCompletionsStreamResponse = bedrock_resp.try_into()?; - let responses_resp = chat_resp.try_into()?; + let chat_resp: crate::apis::openai::ChatCompletionsStreamResponse = + bedrock_resp.try_into()?; + let responses_resp: ResponsesAPIStreamEvent = chat_resp.try_into()?; Ok(ProviderStreamResponseType::ResponseAPIStreamEvent( - responses_resp, + Box::new(responses_resp), )) } _ => Err(std::io::Error::new( @@ -287,7 +302,11 @@ impl TryFrom<(SseEvent, &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for S type Error = Box; fn try_from( - (sse_event, client_api, upstream_api): (SseEvent, &SupportedAPIsFromClient, &SupportedUpstreamAPIs), + (sse_event, client_api, upstream_api): ( + SseEvent, + &SupportedAPIsFromClient, + &SupportedUpstreamAPIs, + ), ) -> Result { // Create a new transformed event based on the original let mut transformed_event = sse_event; @@ -296,7 +315,11 @@ impl TryFrom<(SseEvent, &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for S if transformed_event.is_done() { // For OpenAI client APIs (ChatCompletions and ResponsesAPI), keep [DONE] as-is // For Anthropic client API, it will be transformed via ProviderStreamResponseType - if matches!(client_api, SupportedAPIsFromClient::OpenAIChatCompletions(_) | SupportedAPIsFromClient::OpenAIResponsesAPI(_)) { + if matches!( + client_api, + SupportedAPIsFromClient::OpenAIChatCompletions(_) + | SupportedAPIsFromClient::OpenAIResponsesAPI(_) + ) { // Keep the [DONE] marker as-is for OpenAI clients transformed_event.sse_transformed_lines = "data: [DONE]".to_string(); return Ok(transformed_event); @@ -328,7 +351,7 @@ impl TryFrom<(SseEvent, &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for S // OpenAI clients don't expect separate event: lines // Suppress upstream Anthropic event-only lines if transformed_event.is_event_only() && transformed_event.event.is_some() { - transformed_event.sse_transformed_lines = format!("\n"); + transformed_event.sse_transformed_lines = "\n".to_string(); } } _ => { @@ -345,7 +368,8 @@ impl TryFrom<(SseEvent, &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for S ( SupportedAPIsFromClient::AnthropicMessagesAPI(_), SupportedUpstreamAPIs::AnthropicMessagesAPI(_), - ) | ( + ) + | ( SupportedAPIsFromClient::OpenAIResponsesAPI(_), SupportedUpstreamAPIs::OpenAIResponsesAPI(_), ) => { @@ -415,7 +439,7 @@ impl openai_event, )) } - ( + ( SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), SupportedAPIsFromClient::OpenAIResponsesAPI(_), ) => { @@ -428,7 +452,7 @@ impl openai_chat_completions_event.try_into()?; Ok(ProviderStreamResponseType::ResponseAPIStreamEvent( - openai_responses_api_event, + Box::new(openai_responses_api_event), )) } _ => Err("Unsupported API combination for event-stream decoding".into()), @@ -445,11 +469,11 @@ impl mod tests { use super::*; use crate::apis::streaming_shapes::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder; - use crate::clients::endpoints::SupportedAPIsFromClient; use crate::apis::streaming_shapes::sse::SseStreamIter; + use crate::clients::endpoints::SupportedAPIsFromClient; use serde_json::json; - #[test] + #[test] fn test_sse_event_parsing() { // Test valid SSE data line let line = "data: {\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n"; @@ -792,7 +816,7 @@ mod tests { // Simulate chunked network arrivals with realistic chunk sizes // Using varying chunk sizes to test partial frame handling let mut buffer = BytesMut::new(); - let chunk_size_pattern = vec![500, 1000, 750, 1200, 800, 1500]; + let chunk_size_pattern = [500, 1000, 750, 1200, 800, 1500]; let mut offset = 0; let mut total_frames = 0; let mut chunk_num = 0; @@ -837,7 +861,7 @@ mod tests { ); } - #[test] + #[test] fn test_bedrock_decoded_frame_to_provider_response() { test_bedrock_conversion(false); } @@ -879,8 +903,9 @@ mod tests { let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer); - let client_api = - SupportedAPIsFromClient::AnthropicMessagesAPI(crate::apis::anthropic::AnthropicApi::Messages); + let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI( + crate::apis::anthropic::AnthropicApi::Messages, + ); let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream( crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream, ); @@ -966,8 +991,9 @@ mod tests { let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer); - let client_api = - SupportedAPIsFromClient::AnthropicMessagesAPI(crate::apis::anthropic::AnthropicApi::Messages); + let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI( + crate::apis::anthropic::AnthropicApi::Messages, + ); let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream( crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream, ); @@ -1051,7 +1077,6 @@ mod tests { ); } - #[test] fn test_sse_event_transformation_openai_to_anthropic_message_delta() { use crate::apis::anthropic::AnthropicApi; @@ -1079,8 +1104,8 @@ mod tests { let sse_event = SseEvent { data: Some(openai_stream_chunk.to_string()), event: None, - raw_line: format!("data: {}", openai_stream_chunk.to_string()), - sse_transformed_lines: format!("data: {}", openai_stream_chunk.to_string()), + raw_line: format!("data: {}", openai_stream_chunk), + sse_transformed_lines: format!("data: {}", openai_stream_chunk), provider_stream_response: None, }; @@ -1101,7 +1126,8 @@ mod tests { // Verify the event was transformed to Anthropic format // This should contain message_delta with stop_reason and usage assert!( - buffer.contains("event: message_delta") || buffer.contains("\"type\":\"message_delta\""), + buffer.contains("event: message_delta") + || buffer.contains("\"type\":\"message_delta\""), "Should contain message_delta in transformed event" ); @@ -1134,8 +1160,8 @@ mod tests { let sse_event = SseEvent { data: Some(openai_stream_chunk.to_string()), event: None, - raw_line: format!("data: {}", openai_stream_chunk.to_string()), - sse_transformed_lines: format!("data: {}", openai_stream_chunk.to_string()), + raw_line: format!("data: {}", openai_stream_chunk), + sse_transformed_lines: format!("data: {}", openai_stream_chunk), provider_stream_response: None, }; @@ -1223,8 +1249,8 @@ mod tests { let sse_event = SseEvent { data: Some(anthropic_event.to_string()), event: None, - raw_line: format!("data: {}", anthropic_event.to_string()), - sse_transformed_lines: format!("data: {}", anthropic_event.to_string()), + raw_line: format!("data: {}", anthropic_event), + sse_transformed_lines: format!("data: {}", anthropic_event), provider_stream_response: None, }; @@ -1314,8 +1340,8 @@ mod tests { let sse_event = SseEvent { data: Some(openai_stream_chunk.to_string()), event: None, - raw_line: format!("data: {}", openai_stream_chunk.to_string()), - sse_transformed_lines: format!("data: {}", openai_stream_chunk.to_string()), + raw_line: format!("data: {}", openai_stream_chunk), + sse_transformed_lines: format!("data: {}", openai_stream_chunk), provider_stream_response: None, }; diff --git a/crates/hermesllm/src/transforms/lib.rs b/crates/hermesllm/src/transforms/lib.rs index 53a7621e..a44f8d79 100644 --- a/crates/hermesllm/src/transforms/lib.rs +++ b/crates/hermesllm/src/transforms/lib.rs @@ -11,11 +11,11 @@ pub trait ExtractText { /// Trait for utility functions on content collections pub trait ContentUtils { fn extract_tool_calls(&self) -> Result>, TransformError>; - fn split_for_openai( - &self, - ) -> Result<(Vec, Vec, Vec<(String, String, bool)>), TransformError>; + fn split_for_openai(&self) -> Result; } +pub type SplitForOpenAIResult = (Vec, Vec, Vec<(String, String, bool)>); + /// Helper to create a current unix timestamp pub fn current_timestamp() -> u64 { SystemTime::now() diff --git a/crates/hermesllm/src/transforms/request/from_anthropic.rs b/crates/hermesllm/src/transforms/request/from_anthropic.rs index 9dedc313..c07be4e5 100644 --- a/crates/hermesllm/src/transforms/request/from_anthropic.rs +++ b/crates/hermesllm/src/transforms/request/from_anthropic.rs @@ -38,7 +38,7 @@ impl TryFrom for ChatCompletionsRequest { } // Convert tools and tool choice - let openai_tools = req.tools.map(|tools| convert_anthropic_tools(tools)); + let openai_tools = req.tools.map(convert_anthropic_tools); let (openai_tool_choice, parallel_tool_calls) = convert_anthropic_tool_choice(req.tool_choice); @@ -218,18 +218,18 @@ impl TryFrom for Vec { } // Role Conversions -impl Into for MessagesRole { - fn into(self) -> Role { - match self { +impl From for Role { + fn from(val: MessagesRole) -> Self { + match val { MessagesRole::User => Role::User, MessagesRole::Assistant => Role::Assistant, } } } -impl Into for FinishReason { - fn into(self) -> MessagesStopReason { - match self { +impl From for MessagesStopReason { + fn from(val: FinishReason) -> Self { + match val { FinishReason::Stop => MessagesStopReason::EndTurn, FinishReason::Length => MessagesStopReason::MaxTokens, FinishReason::ToolCalls => MessagesStopReason::ToolUse, @@ -239,11 +239,11 @@ impl Into for FinishReason { } } -impl Into for Usage { - fn into(self) -> MessagesUsage { +impl From for MessagesUsage { + fn from(val: Usage) -> Self { MessagesUsage { - input_tokens: self.prompt_tokens, - output_tokens: self.completion_tokens, + input_tokens: val.prompt_tokens, + output_tokens: val.completion_tokens, cache_creation_input_tokens: None, cache_read_input_tokens: None, } @@ -251,9 +251,9 @@ impl Into for Usage { } // System Prompt Conversions -impl Into for MessagesSystemPrompt { - fn into(self) -> Message { - let system_content = match self { +impl From for Message { + fn from(val: MessagesSystemPrompt) -> Self { + let system_content = match val { MessagesSystemPrompt::Single(text) => MessageContent::Text(text), MessagesSystemPrompt::Blocks(blocks) => MessageContent::Text(blocks.extract_text()), }; @@ -384,12 +384,8 @@ impl TryFrom for BedrockMessage { ToolResultContent::Blocks(blocks) => { let mut result_blocks = Vec::new(); for result_block in blocks { - match result_block { - crate::apis::anthropic::MessagesContentBlock::Text { text, .. } => { - result_blocks.push(ToolResultContentBlock::Text { text }); - } - // For now, skip other content types in tool results - _ => {} + if let crate::apis::anthropic::MessagesContentBlock::Text { text, .. } = result_block { + result_blocks.push(ToolResultContentBlock::Text { text }); } } result_blocks diff --git a/crates/hermesllm/src/transforms/request/from_openai.rs b/crates/hermesllm/src/transforms/request/from_openai.rs index 27366f4d..2a133041 100644 --- a/crates/hermesllm/src/transforms/request/from_openai.rs +++ b/crates/hermesllm/src/transforms/request/from_openai.rs @@ -14,7 +14,8 @@ use crate::apis::openai::{ }; use crate::apis::openai_responses::{ - ResponsesAPIRequest, InputContent, InputItem, InputParam, MessageRole, Modality, ReasoningEffort, Tool as ResponsesTool, ToolChoice as ResponsesToolChoice + InputContent, InputItem, InputParam, MessageRole, Modality, ReasoningEffort, + ResponsesAPIRequest, Tool as ResponsesTool, ToolChoice as ResponsesToolChoice, }; use crate::clients::TransformError; use crate::transforms::lib::ExtractText; @@ -27,9 +28,9 @@ type AnthropicMessagesRequest = MessagesRequest; // MAIN REQUEST TRANSFORMATIONS // ============================================================================ -impl Into for Message { - fn into(self) -> MessagesSystemPrompt { - let system_text = match self.content { +impl From for MessagesSystemPrompt { + fn from(val: Message) -> Self { + let system_text = match val.content { MessageContent::Text(text) => text, MessageContent::Parts(parts) => parts.extract_text(), }; @@ -163,7 +164,7 @@ impl TryFrom for BedrockMessage { let has_tool_calls = message .tool_calls .as_ref() - .map_or(false, |calls| !calls.is_empty()); + .is_some_and(|calls| !calls.is_empty()); // Add text content if it's non-empty, or if we have no tool calls (to avoid empty content) if !text_content.is_empty() { @@ -252,7 +253,6 @@ impl TryFrom for ChatCompletionsRequest { type Error = TransformError; fn try_from(req: ResponsesAPIRequest) -> Result { - // Convert input to messages let messages = match req.input { InputParam::Text(text) => { @@ -282,50 +282,27 @@ impl TryFrom for ChatCompletionsRequest { // Convert each input item for item in items { - match item { - InputItem::Message(input_msg) => { - let role = match input_msg.role { - MessageRole::User => Role::User, - MessageRole::Assistant => Role::Assistant, - MessageRole::System => Role::System, - MessageRole::Developer => Role::System, // Map developer to system - }; + if let InputItem::Message(input_msg) = item { + let role = match input_msg.role { + MessageRole::User => Role::User, + MessageRole::Assistant => Role::Assistant, + MessageRole::System => Role::System, + MessageRole::Developer => Role::System, // Map developer to system + }; - // Convert content based on MessageContent type - let content = match &input_msg.content { - crate::apis::openai_responses::MessageContent::Text(text) => { - // Simple text content - MessageContent::Text(text.clone()) - } - crate::apis::openai_responses::MessageContent::Items(content_items) => { - // Check if it's a single text item (can use simple text format) - if content_items.len() == 1 { - if let InputContent::InputText { text } = &content_items[0] { - MessageContent::Text(text.clone()) - } else { - // Single non-text item - use parts format - MessageContent::Parts( - content_items.iter() - .filter_map(|c| match c { - InputContent::InputText { text } => { - Some(crate::apis::openai::ContentPart::Text { text: text.clone() }) - } - InputContent::InputImage { image_url, .. } => { - Some(crate::apis::openai::ContentPart::ImageUrl { - image_url: crate::apis::openai::ImageUrl { - url: image_url.clone(), - detail: None, - } - }) - } - InputContent::InputFile { .. } => None, // Skip files for now - InputContent::InputAudio { .. } => None, // Skip audio for now - }) - .collect() - ) - } + // Convert content based on MessageContent type + let content = match &input_msg.content { + crate::apis::openai_responses::MessageContent::Text(text) => { + // Simple text content + MessageContent::Text(text.clone()) + } + crate::apis::openai_responses::MessageContent::Items(content_items) => { + // Check if it's a single text item (can use simple text format) + if content_items.len() == 1 { + if let InputContent::InputText { text } = &content_items[0] { + MessageContent::Text(text.clone()) } else { - // Multiple content items - convert to parts + // Single non-text item - use parts format MessageContent::Parts( content_items.iter() .filter_map(|c| match c { @@ -346,20 +323,41 @@ impl TryFrom for ChatCompletionsRequest { .collect() ) } + } else { + // Multiple content items - convert to parts + MessageContent::Parts( + content_items + .iter() + .filter_map(|c| match c { + InputContent::InputText { text } => { + Some(crate::apis::openai::ContentPart::Text { + text: text.clone(), + }) + } + InputContent::InputImage { image_url, .. } => Some( + crate::apis::openai::ContentPart::ImageUrl { + image_url: crate::apis::openai::ImageUrl { + url: image_url.clone(), + detail: None, + }, + }, + ), + InputContent::InputFile { .. } => None, // Skip files for now + InputContent::InputAudio { .. } => None, // Skip audio for now + }) + .collect(), + ) } - }; + } + }; - converted_messages.push(Message { - role, - content, - name: None, - tool_call_id: None, - tool_calls: None, - }); - } - // Skip non-message items (references, outputs) for now - // These would need special handling in chat completions format - _ => {} + converted_messages.push(Message { + role, + content, + name: None, + tool_call_id: None, + tool_calls: None, + }); } } @@ -474,7 +472,7 @@ impl TryFrom for AnthropicMessagesRequest { } // Convert tools and tool choice - let anthropic_tools = req.tools.map(|tools| convert_openai_tools(tools)); + let anthropic_tools = req.tools.map(convert_openai_tools); let anthropic_tool_choice = convert_openai_tool_choice(req.tool_choice, req.parallel_tool_calls); diff --git a/crates/hermesllm/src/transforms/response/output_to_input.rs b/crates/hermesllm/src/transforms/response/output_to_input.rs index 8ab08205..e62f32b8 100644 --- a/crates/hermesllm/src/transforms/response/output_to_input.rs +++ b/crates/hermesllm/src/transforms/response/output_to_input.rs @@ -13,18 +13,14 @@ use crate::apis::openai_responses::{ pub fn convert_responses_output_to_input_items(output: &OutputItem) -> Option { match output { // Convert output messages to input messages - OutputItem::Message { - role, content, .. - } => { + OutputItem::Message { role, content, .. } => { let input_content: Vec = content .iter() .filter_map(|c| match c { - OutputContent::OutputText { text, .. } => Some(InputContent::InputText { - text: text.clone(), - }), - OutputContent::OutputAudio { - data, .. - } => Some(InputContent::InputAudio { + OutputContent::OutputText { text, .. } => { + Some(InputContent::InputText { text: text.clone() }) + } + OutputContent::OutputAudio { data, .. } => Some(InputContent::InputAudio { data: data.clone(), format: None, // Format not preserved in output }), @@ -84,7 +80,7 @@ pub fn outputs_to_inputs(outputs: &[OutputItem]) -> Vec { #[cfg(test)] mod tests { use super::*; - use crate::apis::openai_responses::{OutputItemStatus}; + use crate::apis::openai_responses::OutputItemStatus; #[test] fn test_output_message_to_input() { @@ -135,14 +131,12 @@ mod tests { InputItem::Message(msg) => { assert!(matches!(msg.role, MessageRole::Assistant)); match &msg.content { - MessageContent::Items(items) => { - match &items[0] { - InputContent::InputText { text } => { - assert!(text.contains("get_weather")); - } - _ => panic!("Expected InputText"), + MessageContent::Items(items) => match &items[0] { + InputContent::InputText { text } => { + assert!(text.contains("get_weather")); } - } + _ => panic!("Expected InputText"), + }, _ => panic!("Expected MessageContent::Items"), } } diff --git a/crates/hermesllm/src/transforms/response/to_anthropic.rs b/crates/hermesllm/src/transforms/response/to_anthropic.rs index 0326fdb3..7eb34d56 100644 --- a/crates/hermesllm/src/transforms/response/to_anthropic.rs +++ b/crates/hermesllm/src/transforms/response/to_anthropic.rs @@ -1,7 +1,6 @@ use crate::apis::amazon_bedrock::{ConverseOutput, ConverseResponse, StopReason}; use crate::apis::anthropic::{ - MessagesContentBlock, MessagesResponse, - MessagesRole, MessagesStopReason, MessagesUsage, + MessagesContentBlock, MessagesResponse, MessagesRole, MessagesStopReason, MessagesUsage, }; use crate::apis::openai::ChatCompletionsResponse; use crate::clients::TransformError; @@ -115,7 +114,6 @@ impl TryFrom for MessagesResponse { } } - /// Convert Bedrock Message to Anthropic content blocks /// /// This function handles the conversion between Amazon Bedrock Converse API format diff --git a/crates/hermesllm/src/transforms/response/to_openai.rs b/crates/hermesllm/src/transforms/response/to_openai.rs index ee526c71..9e3276c9 100644 --- a/crates/hermesllm/src/transforms/response/to_openai.rs +++ b/crates/hermesllm/src/transforms/response/to_openai.rs @@ -1,9 +1,5 @@ -use crate::apis::amazon_bedrock::{ - ConverseOutput, ConverseResponse, StopReason, -}; -use crate::apis::anthropic::{ - MessagesContentBlock, MessagesResponse, MessagesUsage, -}; +use crate::apis::amazon_bedrock::{ConverseOutput, ConverseResponse, StopReason}; +use crate::apis::anthropic::{MessagesContentBlock, MessagesResponse, MessagesUsage}; use crate::apis::openai::{ ChatCompletionsResponse, Choice, FinishReason, MessageContent, ResponseMessage, Role, Usage, }; @@ -16,12 +12,12 @@ use crate::transforms::lib::*; // ============================================================================ // Usage Conversions -impl Into for MessagesUsage { - fn into(self) -> Usage { +impl From for Usage { + fn from(val: MessagesUsage) -> Self { Usage { - prompt_tokens: self.input_tokens, - completion_tokens: self.output_tokens, - total_tokens: self.input_tokens + self.output_tokens, + prompt_tokens: val.input_tokens, + completion_tokens: val.output_tokens, + total_tokens: val.input_tokens + val.output_tokens, prompt_tokens_details: None, completion_tokens_details: None, } @@ -203,7 +199,6 @@ impl TryFrom for ResponsesAPIResponse { } } - impl TryFrom for ChatCompletionsResponse { type Error = TransformError; @@ -415,7 +410,6 @@ fn convert_anthropic_content_to_openai( Ok(MessageContent::Text(text_parts.join("\n"))) } - #[cfg(test)] mod tests { use super::*; @@ -994,8 +988,15 @@ mod tests { let responses_api: ResponsesAPIResponse = chat_response.try_into().unwrap(); // Response ID should be generated with resp_ prefix - assert!(responses_api.id.starts_with("resp_"), "Response ID should start with 'resp_'"); - assert_eq!(responses_api.id.len(), 37, "Response ID should be resp_ + 32 char UUID"); + assert!( + responses_api.id.starts_with("resp_"), + "Response ID should start with 'resp_'" + ); + assert_eq!( + responses_api.id.len(), + 37, + "Response ID should be resp_ + 32 char UUID" + ); assert_eq!(responses_api.object, "response"); assert_eq!(responses_api.model, "gpt-4"); @@ -1008,11 +1009,7 @@ mod tests { // Check output items assert_eq!(responses_api.output.len(), 1); match &responses_api.output[0] { - OutputItem::Message { - role, - content, - .. - } => { + OutputItem::Message { role, content, .. } => { assert_eq!(role, "assistant"); assert_eq!(content.len(), 1); match &content[0] { @@ -1163,6 +1160,9 @@ mod tests { } // Verify status is Completed for tool_calls finish reason - assert!(matches!(responses_api.status, crate::apis::openai_responses::ResponseStatus::Completed)); + assert!(matches!( + responses_api.status, + crate::apis::openai_responses::ResponseStatus::Completed + )); } } diff --git a/crates/hermesllm/src/transforms/response_streaming/to_anthropic_streaming.rs b/crates/hermesllm/src/transforms/response_streaming/to_anthropic_streaming.rs index b8cac631..5dbf09ef 100644 --- a/crates/hermesllm/src/transforms/response_streaming/to_anthropic_streaming.rs +++ b/crates/hermesllm/src/transforms/response_streaming/to_anthropic_streaming.rs @@ -1,12 +1,9 @@ -use crate::apis::amazon_bedrock::{ - ContentBlockDelta, ConverseStreamEvent, -}; +use crate::apis::amazon_bedrock::{ContentBlockDelta, ConverseStreamEvent}; use crate::apis::anthropic::{ - MessagesContentBlock, MessagesContentDelta, MessagesMessageDelta, - MessagesRole, MessagesStopReason, MessagesStreamEvent, MessagesStreamMessage, MessagesUsage, -}; -use crate::apis::openai::{ ChatCompletionsStreamResponse, ToolCallDelta, + MessagesContentBlock, MessagesContentDelta, MessagesMessageDelta, MessagesRole, + MessagesStopReason, MessagesStreamEvent, MessagesStreamMessage, MessagesUsage, }; +use crate::apis::openai::{ChatCompletionsStreamResponse, ToolCallDelta}; use crate::clients::TransformError; use serde_json::Value; @@ -86,10 +83,10 @@ impl TryFrom for MessagesStreamEvent { } } -impl Into for MessagesStreamEvent { - fn into(self) -> String { - let transformed_json = serde_json::to_string(&self).unwrap_or_default(); - let event_type = match &self { +impl From for String { + fn from(val: MessagesStreamEvent) -> Self { + let transformed_json = serde_json::to_string(&val).unwrap_or_default(); + let event_type = match &val { MessagesStreamEvent::MessageStart { .. } => "message_start", MessagesStreamEvent::ContentBlockStart { .. } => "content_block_start", MessagesStreamEvent::ContentBlockDelta { .. } => "content_block_delta", @@ -194,10 +191,18 @@ impl TryFrom for MessagesStreamEvent { let anthropic_stop_reason = match stop_event.stop_reason { crate::apis::amazon_bedrock::StopReason::EndTurn => MessagesStopReason::EndTurn, crate::apis::amazon_bedrock::StopReason::ToolUse => MessagesStopReason::ToolUse, - crate::apis::amazon_bedrock::StopReason::MaxTokens => MessagesStopReason::MaxTokens, - crate::apis::amazon_bedrock::StopReason::StopSequence => MessagesStopReason::EndTurn, - crate::apis::amazon_bedrock::StopReason::GuardrailIntervened => MessagesStopReason::Refusal, - crate::apis::amazon_bedrock::StopReason::ContentFiltered => MessagesStopReason::Refusal, + crate::apis::amazon_bedrock::StopReason::MaxTokens => { + MessagesStopReason::MaxTokens + } + crate::apis::amazon_bedrock::StopReason::StopSequence => { + MessagesStopReason::EndTurn + } + crate::apis::amazon_bedrock::StopReason::GuardrailIntervened => { + MessagesStopReason::Refusal + } + crate::apis::amazon_bedrock::StopReason::ContentFiltered => { + MessagesStopReason::Refusal + } }; Ok(MessagesStreamEvent::MessageDelta { diff --git a/crates/hermesllm/src/transforms/response_streaming/to_openai_streaming.rs b/crates/hermesllm/src/transforms/response_streaming/to_openai_streaming.rs index ca3d049b..328317bc 100644 --- a/crates/hermesllm/src/transforms/response_streaming/to_openai_streaming.rs +++ b/crates/hermesllm/src/transforms/response_streaming/to_openai_streaming.rs @@ -1,8 +1,10 @@ -use crate::apis::amazon_bedrock::{ ConverseStreamEvent, StopReason}; +use crate::apis::amazon_bedrock::{ConverseStreamEvent, StopReason}; use crate::apis::anthropic::{ - MessagesContentBlock, MessagesContentDelta, MessagesStopReason, MessagesStreamEvent}; -use crate::apis::openai::{ ChatCompletionsStreamResponse,FinishReason, - FunctionCallDelta, MessageDelta, Role, StreamChoice, ToolCallDelta, Usage, + MessagesContentBlock, MessagesContentDelta, MessagesStopReason, MessagesStreamEvent, +}; +use crate::apis::openai::{ + ChatCompletionsStreamResponse, FinishReason, FunctionCallDelta, MessageDelta, Role, + StreamChoice, ToolCallDelta, Usage, }; use crate::apis::openai_responses::ResponsesAPIStreamEvent; @@ -58,11 +60,14 @@ impl TryFrom for ChatCompletionsStreamResponse { None, )), - MessagesStreamEvent::ContentBlockStart { content_block, index } => { - convert_content_block_start(content_block, index) - } + MessagesStreamEvent::ContentBlockStart { + content_block, + index, + } => convert_content_block_start(content_block, index), - MessagesStreamEvent::ContentBlockDelta { delta, index } => convert_content_delta(delta, index), + MessagesStreamEvent::ContentBlockDelta { delta, index } => { + convert_content_delta(delta, index) + } MessagesStreamEvent::ContentBlockStop { .. } => Ok(create_empty_openai_chunk()), @@ -427,9 +432,9 @@ fn create_empty_openai_chunk() -> ChatCompletionsStreamResponse { } // Stop Reason Conversions -impl Into for MessagesStopReason { - fn into(self) -> FinishReason { - match self { +impl From for FinishReason { + fn from(val: MessagesStopReason) -> Self { + match val { MessagesStopReason::EndTurn => FinishReason::Stop, MessagesStopReason::MaxTokens => FinishReason::Length, MessagesStopReason::StopSequence => FinishReason::Stop, @@ -456,34 +461,37 @@ impl TryFrom for ResponsesAPIStreamEvent { if let Some(tool_call) = tool_calls.first() { // Extract call_id and name if available (metadata from initial event) let call_id = tool_call.id.clone(); - let function_name = tool_call.function.as_ref() - .and_then(|f| f.name.clone()); + let function_name = tool_call.function.as_ref().and_then(|f| f.name.clone()); // Check if we have function metadata (name, id) if let Some(function) = &tool_call.function { // If we have arguments delta, return that if let Some(args) = &function.arguments { - return Ok(ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { - output_index: choice.index as i32, - item_id: "".to_string(), // Buffer will fill this - delta: args.clone(), - sequence_number: 0, // Buffer will fill this - call_id, - name: function_name, - }); + return Ok( + ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { + output_index: choice.index as i32, + item_id: "".to_string(), // Buffer will fill this + delta: args.clone(), + sequence_number: 0, // Buffer will fill this + call_id, + name: function_name, + }, + ); } // If we have function name but no arguments yet (initial tool call event) // Return an empty arguments delta so the buffer knows to create the item if function.name.is_some() { - return Ok(ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { - output_index: choice.index as i32, - item_id: "".to_string(), // Buffer will fill this - delta: "".to_string(), // Empty delta signals this is the initial event - sequence_number: 0, // Buffer will fill this - call_id, - name: function_name, - }); + return Ok( + ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { + output_index: choice.index as i32, + item_id: "".to_string(), // Buffer will fill this + delta: "".to_string(), // Empty delta signals this is the initial event + sequence_number: 0, // Buffer will fill this + call_id, + name: function_name, + }, + ); } } } diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 91182e52..066fef7f 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -94,8 +94,8 @@ impl StreamContext { fn request_identifier(&self) -> String { self.request_id .as_ref() - .filter(|id| !id.is_empty()) // Filter out empty strings - .map(|id| id.clone()) + .filter(|id| !id.is_empty()) + .cloned() .unwrap_or_else(|| "NO_REQUEST_ID".to_string()) } fn llm_provider(&self) -> &LlmProvider { @@ -504,7 +504,7 @@ impl StreamContext { // Get accumulated bytes from buffer and return match self.sse_buffer.as_mut() { Some(buffer) => { - let bytes = buffer.into_bytes(); + let bytes = buffer.to_bytes(); if !bytes.is_empty() { let content = String::from_utf8_lossy(&bytes); debug!( @@ -623,7 +623,7 @@ impl StreamContext { // Get accumulated bytes from buffer and return match self.sse_buffer.as_mut() { Some(buffer) => { - let bytes = buffer.into_bytes(); + let bytes = buffer.to_bytes(); if !bytes.is_empty() { let content = String::from_utf8_lossy(&bytes); debug!( diff --git a/crates/prompt_gateway/src/http_context.rs b/crates/prompt_gateway/src/http_context.rs index fc66de12..1b6afbab 100644 --- a/crates/prompt_gateway/src/http_context.rs +++ b/crates/prompt_gateway/src/http_context.rs @@ -142,8 +142,7 @@ impl HttpContext for StreamContext { let last_user_prompt = match deserialized_body .messages .iter() - .filter(|msg| msg.role == USER_ROLE) - .last() + .rfind(|msg| msg.role == USER_ROLE) { Some(content) => content, None => { @@ -155,11 +154,8 @@ impl HttpContext for StreamContext { self.user_prompt = Some(last_user_prompt.clone()); // convert prompt targets to ChatCompletionTool - let tool_calls: Vec = self - .prompt_targets - .iter() - .map(|(_, pt)| pt.into()) - .collect(); + let tool_calls: Vec = + self.prompt_targets.values().map(|pt| pt.into()).collect(); let mut metadata = deserialized_body.metadata.clone(); diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index 1e9d507b..5f464930 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -376,21 +376,22 @@ impl StreamContext { // Parse arguments JSON string into HashMap // Note: convert from serde_json::Value to serde_yaml::Value for compatibility - let tool_params: Option> = match serde_json::from_str::>(tool_params_str) { - Ok(json_params) => { - let yaml_params: HashMap = json_params - .into_iter() - .filter_map(|(k, v)| { - serde_yaml::to_value(&v).ok().map(|yaml_v| (k, yaml_v)) - }) - .collect(); - Some(yaml_params) - }, - Err(e) => { - warn!("Failed to parse tool call arguments: {}", e); - None - } - }; + let tool_params: Option> = + match serde_json::from_str::>(tool_params_str) { + Ok(json_params) => { + let yaml_params: HashMap = json_params + .into_iter() + .filter_map(|(k, v)| { + serde_yaml::to_value(&v).ok().map(|yaml_v| (k, yaml_v)) + }) + .collect(); + Some(yaml_params) + } + Err(e) => { + warn!("Failed to parse tool call arguments: {}", e); + None + } + }; let endpoint_details = prompt_target.endpoint.as_ref().unwrap(); let endpoint_path: String = endpoint_details @@ -629,10 +630,10 @@ impl StreamContext { } }; - if system_prompt.is_some() { + if let Some(system_prompt_text) = system_prompt { let system_prompt_message = Message { role: SYSTEM_ROLE.to_string(), - content: Some(ContentType::Text(system_prompt.unwrap())), + content: Some(ContentType::Text(system_prompt_text)), model: None, tool_calls: None, tool_call_id: None,