cargo clippy (#660)

This commit is contained in:
Adil Hafeez 2025-12-25 21:08:37 -08:00 committed by GitHub
parent c75e7606f9
commit ca95ffb63d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
62 changed files with 1864 additions and 1187 deletions

View file

@ -13,19 +13,22 @@ repos:
name: cargo-fmt
language: system
types: [file, rust]
entry: bash -c "cd crates/llm_gateway && cargo fmt"
entry: bash -c "cd crates && cargo fmt --all -- --check"
pass_filenames: false
- id: cargo-clippy
name: cargo-clippy
language: system
types: [file, rust]
entry: bash -c "cd crates/llm_gateway && cargo clippy --all"
entry: bash -c "cd crates && cargo clippy --locked --offline --all-targets --all-features -- -D warnings || cargo clippy --locked --all-targets --all-features -- -D warnings"
pass_filenames: false
- id: cargo-test
name: cargo-test
language: system
types: [file, rust]
entry: bash -c "cd crates && cargo test --lib"
pass_filenames: false
- repo: https://github.com/psf/black
rev: 23.1.0

View file

@ -3,7 +3,7 @@ use std::time::{Instant, SystemTime};
use bytes::Bytes;
use common::consts::TRACE_PARENT_HEADER;
use common::traces::{SpanBuilder, SpanKind, parse_traceparent, generate_random_span_id};
use common::traces::{generate_random_span_id, parse_traceparent, SpanBuilder, SpanKind};
use hermesllm::apis::OpenAIMessage;
use hermesllm::clients::SupportedAPIsFromClient;
use hermesllm::providers::request::ProviderRequest;
@ -18,7 +18,7 @@ use super::agent_selector::{AgentSelectionError, AgentSelector};
use super::pipeline_processor::{PipelineError, PipelineProcessor};
use super::response_handler::ResponseHandler;
use crate::router::plano_orchestrator::OrchestratorService;
use crate::tracing::{OperationNameBuilder, operation_component, http};
use crate::tracing::{http, operation_component, OperationNameBuilder};
/// Main errors for agent chat completions
#[derive(Debug, thiserror::Error)]
@ -61,7 +61,6 @@ pub async fn agent_chat(
body,
}) = &err
{
warn!(
"Client error from agent '{}' (HTTP {}): {}",
agent, status, body
@ -77,8 +76,8 @@ pub async fn agent_chat(
let json_string = error_json.to_string();
let mut response = Response::new(ResponseHandler::create_full_body(json_string));
*response.status_mut() = hyper::StatusCode::from_u16(*status)
.unwrap_or(hyper::StatusCode::BAD_REQUEST);
*response.status_mut() =
hyper::StatusCode::from_u16(*status).unwrap_or(hyper::StatusCode::BAD_REQUEST);
response.headers_mut().insert(
hyper::header::CONTENT_TYPE,
"application/json".parse().unwrap(),
@ -234,8 +233,18 @@ async fn handle_agent_chat(
.with_attribute(http::TARGET, "/agents/select")
.with_attribute("selection.listener", listener.name.clone())
.with_attribute("selection.agent_count", selected_agents.len().to_string())
.with_attribute("selection.agents", selected_agents.iter().map(|a| a.id.as_str()).collect::<Vec<_>>().join(","))
.with_attribute("duration_ms", format!("{:.2}", selection_elapsed.as_secs_f64() * 1000.0));
.with_attribute(
"selection.agents",
selected_agents
.iter()
.map(|a| a.id.as_str())
.collect::<Vec<_>>()
.join(","),
)
.with_attribute(
"duration_ms",
format!("{:.2}", selection_elapsed.as_secs_f64() * 1000.0),
);
if !trace_id.is_empty() {
selection_span_builder = selection_span_builder.with_trace_id(trace_id.clone());
@ -318,8 +327,14 @@ async fn handle_agent_chat(
.with_attribute(http::METHOD, "POST")
.with_attribute(http::TARGET, full_path)
.with_attribute("agent.name", agent_name.clone())
.with_attribute("agent.sequence", format!("{}/{}", agent_index + 1, agent_count))
.with_attribute("duration_ms", format!("{:.2}", agent_elapsed.as_secs_f64() * 1000.0));
.with_attribute(
"agent.sequence",
format!("{}/{}", agent_index + 1, agent_count),
)
.with_attribute(
"duration_ms",
format!("{:.2}", agent_elapsed.as_secs_f64() * 1000.0),
);
if !trace_id.is_empty() {
span_builder = span_builder.with_trace_id(trace_id.clone());
@ -333,7 +348,10 @@ async fn handle_agent_chat(
// If this is the last agent, return the streaming response
if is_last_agent {
info!("Completed agent chain, returning response from last agent: {}", agent_name);
info!(
"Completed agent chain, returning response from last agent: {}",
agent_name
);
return response_handler
.create_streaming_response(llm_response)
.await
@ -341,7 +359,10 @@ async fn handle_agent_chat(
}
// For intermediate agents, collect the full response and pass to next agent
debug!("Collecting response from intermediate agent: {}", agent_name);
debug!(
"Collecting response from intermediate agent: {}",
agent_name
);
let response_text = response_handler.collect_full_response(llm_response).await?;
info!(
@ -364,7 +385,6 @@ async fn handle_agent_chat(
});
current_messages.push(last_message);
}
// This should never be reached since we return in the last agent iteration

View file

@ -2,7 +2,7 @@ use std::collections::HashMap;
use std::sync::Arc;
use common::configuration::{
Agent, AgentFilterChain, Listener, AgentUsagePreference, OrchestrationPreference,
Agent, AgentFilterChain, AgentUsagePreference, Listener, OrchestrationPreference,
};
use hermesllm::apis::openai::Message;
use tracing::{debug, warn};

View file

@ -1,20 +1,18 @@
use bytes::Bytes;
use eventsource_stream::Eventsource;
use futures::StreamExt;
use hermesllm::apis::openai::{
ChatCompletionsRequest, ChatCompletionsResponse, Choice, FinishReason, FunctionCall, Message,
MessageContent, ResponseMessage, Role, Tool, ToolCall, Usage,
};
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use hyper::body::Incoming;
use hyper::{Request, Response, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
use thiserror::Error;
use tracing::{info, error};
use futures::StreamExt;
use bytes::Bytes;
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use hyper::body::Incoming;
use hyper::{Request, Response, StatusCode};
use eventsource_stream::Eventsource;
use tracing::{error, info};
// ============================================================================
// CONSTANTS FOR HALLUCINATION DETECTION
@ -273,17 +271,14 @@ impl ArchFunctionHandler {
let mut stack: Vec<char> = Vec::new();
let mut fixed_str = String::new();
let matching_bracket: HashMap<char, char> =
[(')', '('), ('}', '{'), (']', '[')]
.iter()
.cloned()
.collect();
let opening_bracket: HashMap<char, char> = matching_bracket
let matching_bracket: HashMap<char, char> = [(')', '('), ('}', '{'), (']', '[')]
.iter()
.map(|(k, v)| (*v, *k))
.cloned()
.collect();
let opening_bracket: HashMap<char, char> =
matching_bracket.iter().map(|(k, v)| (*v, *k)).collect();
for ch in json_str.chars() {
if ch == '{' || ch == '[' || ch == '(' {
stack.push(ch);
@ -332,12 +327,18 @@ impl ArchFunctionHandler {
// Remove markdown code blocks
let mut content = content.trim().to_string();
if content.starts_with("```") && content.ends_with("```") {
content = content.trim_start_matches("```").trim_end_matches("```").to_string();
content = content
.trim_start_matches("```")
.trim_end_matches("```")
.to_string();
if content.starts_with("json") {
content = content.trim_start_matches("json").to_string();
}
// Trim again after removing code blocks to eliminate internal whitespace
content = content.trim_start_matches(r"\n").trim_end_matches(r"\n").to_string();
content = content
.trim_start_matches(r"\n")
.trim_end_matches(r"\n")
.to_string();
content = content.trim().to_string();
// Unescape the quotes: \" -> "
// The model sometimes returns escaped JSON inside markdown blocks
@ -453,12 +454,12 @@ impl ArchFunctionHandler {
/// Helper method to check if a value matches the expected type
fn check_value_type(&self, value: &Value, target_type: &str) -> bool {
match target_type {
"int" | "integer" => value.is_i64() || value.is_u64(),
"int" | "integer" => value.is_i64() || value.is_u64(),
"float" | "number" => value.is_f64() || value.is_i64() || value.is_u64(),
"bool" | "boolean" => value.is_boolean(),
"str" | "string" => value.is_string(),
"list" | "array" => value.is_array(),
"dict" | "object" => value.is_object(),
"bool" | "boolean" => value.is_boolean(),
"str" | "string" => value.is_string(),
"list" | "array" => value.is_array(),
"dict" | "object" => value.is_object(),
_ => true,
}
}
@ -505,15 +506,19 @@ impl ArchFunctionHandler {
let func_name = &tool_call.function.name;
// Parse arguments as JSON
let func_args: HashMap<String, Value> = match serde_json::from_str(&tool_call.function.arguments) {
Ok(args) => args,
Err(e) => {
verification.is_valid = false;
verification.invalid_tool_call = Some(tool_call.clone());
verification.error_message = format!("Failed to parse arguments for function '{}': {}", func_name, e);
break;
}
};
let func_args: HashMap<String, Value> =
match serde_json::from_str(&tool_call.function.arguments) {
Ok(args) => args,
Err(e) => {
verification.is_valid = false;
verification.invalid_tool_call = Some(tool_call.clone());
verification.error_message = format!(
"Failed to parse arguments for function '{}': {}",
func_name, e
);
break;
}
};
// Check if function is available
if let Some(function_params) = functions.get(func_name) {
@ -541,14 +546,23 @@ impl ArchFunctionHandler {
if let Some(properties_obj) = properties.as_object() {
for (param_name, param_value) in &func_args {
if let Some(param_schema) = properties_obj.get(param_name) {
if let Some(target_type) = param_schema.get("type").and_then(|v| v.as_str()) {
if self.config.support_data_types.contains(&target_type.to_string()) {
if let Some(target_type) =
param_schema.get("type").and_then(|v| v.as_str())
{
if self
.config
.support_data_types
.contains(&target_type.to_string())
{
// Validate data type using helper method
match self.validate_or_convert_parameter(param_value, target_type) {
match self
.validate_or_convert_parameter(param_value, target_type)
{
Ok(is_valid) => {
if !is_valid {
verification.is_valid = false;
verification.invalid_tool_call = Some(tool_call.clone());
verification.invalid_tool_call =
Some(tool_call.clone());
verification.error_message = format!(
"Parameter `{}` is expected to have the data type `{}`, got incompatible type.",
param_name, target_type
@ -558,7 +572,8 @@ impl ArchFunctionHandler {
}
Err(_) => {
verification.is_valid = false;
verification.invalid_tool_call = Some(tool_call.clone());
verification.invalid_tool_call =
Some(tool_call.clone());
verification.error_message = format!(
"Parameter `{}` is expected to have the data type `{}`, got incompatible type.",
param_name, target_type
@ -569,7 +584,10 @@ impl ArchFunctionHandler {
} else {
verification.is_valid = false;
verification.invalid_tool_call = Some(tool_call.clone());
verification.error_message = format!("Data type `{}` is not supported.", target_type);
verification.error_message = format!(
"Data type `{}` is not supported.",
target_type
);
break;
}
}
@ -598,11 +616,8 @@ impl ArchFunctionHandler {
/// Formats the system prompt with tools
pub fn format_system_prompt(&self, tools: &[Tool]) -> Result<String> {
let tools_str = self.convert_tools(tools)?;
let system_prompt = self
.config
.task_prompt
.replace("{tools}", &tools_str)
+ &self.config.format_prompt;
let system_prompt =
self.config.task_prompt.replace("{tools}", &tools_str) + &self.config.format_prompt;
Ok(system_prompt)
}
@ -665,15 +680,22 @@ impl ArchFunctionHandler {
// Strip markdown code blocks
if tool_call_msg.starts_with("```") && tool_call_msg.ends_with("```") {
tool_call_msg = tool_call_msg.trim_start_matches("```").trim_end_matches("```").trim().to_string();
tool_call_msg = tool_call_msg
.trim_start_matches("```")
.trim_end_matches("```")
.trim()
.to_string();
if tool_call_msg.starts_with("json") {
tool_call_msg = tool_call_msg.trim_start_matches("json").trim().to_string();
tool_call_msg =
tool_call_msg.trim_start_matches("json").trim().to_string();
}
}
// Extract function name
if let Ok(parsed) = serde_json::from_str::<Value>(&tool_call_msg) {
if let Some(tool_calls_arr) = parsed.get("tool_calls").and_then(|v| v.as_array()) {
if let Some(tool_calls_arr) =
parsed.get("tool_calls").and_then(|v| v.as_array())
{
if let Some(first_tool_call) = tool_calls_arr.first() {
let func_name = first_tool_call
.get("name")
@ -685,8 +707,10 @@ impl ArchFunctionHandler {
"result": content,
});
content = format!("<tool_response>\n{}\n</tool_response>",
serde_json::to_string(&tool_response)?);
content = format!(
"<tool_response>\n{}\n</tool_response>",
serde_json::to_string(&tool_response)?
);
}
}
}
@ -717,7 +741,7 @@ impl ArchFunctionHandler {
if let Some(instruction) = extra_instruction {
if let Some(last) = processed_messages.last_mut() {
if let MessageContent::Text(content) = &mut last.content {
content.push_str("\n");
content.push('\n');
content.push_str(instruction);
}
}
@ -750,13 +774,11 @@ impl ArchFunctionHandler {
for i in (conversation_idx..messages.len()).rev() {
if let MessageContent::Text(content) = &messages[i].content {
num_tokens += content.len() / 4;
if num_tokens >= max_tokens {
if messages[i].role == Role::User {
// Set message_idx to current position and break
// This matches Python's behavior where message_idx is set before break
message_idx = i;
break;
}
if num_tokens >= max_tokens && messages[i].role == Role::User {
// Set message_idx to current position and break
// This matches Python's behavior where message_idx is set before break
message_idx = i;
break;
}
}
// Only update message_idx if we haven't hit the token limit yet
@ -789,7 +811,11 @@ impl ArchFunctionHandler {
}
/// Helper to create a request with VLLM-specific parameters
fn create_request_with_extra_body(&self, messages: Vec<Message>, stream: bool) -> ChatCompletionsRequest {
fn create_request_with_extra_body(
&self,
messages: Vec<Message>,
stream: bool,
) -> ChatCompletionsRequest {
ChatCompletionsRequest {
model: self.model_name.clone(),
messages,
@ -813,24 +839,38 @@ impl ArchFunctionHandler {
}
/// Makes a streaming request and returns the SSE event stream
async fn make_streaming_request(&self, request: ChatCompletionsRequest) -> Result<std::pin::Pin<Box<dyn futures::Stream<Item = std::result::Result<Value, String>> + Send>>> {
let request_body = serde_json::to_string(&request)
.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Failed to serialize request: {}", e)))?;
async fn make_streaming_request(
&self,
request: ChatCompletionsRequest,
) -> Result<
std::pin::Pin<Box<dyn futures::Stream<Item = std::result::Result<Value, String>> + Send>>,
> {
let request_body = serde_json::to_string(&request).map_err(|e| {
FunctionCallingError::InvalidModelResponse(format!(
"Failed to serialize request: {}",
e
))
})?;
let response = self.http_client
let response = self
.http_client
.post(&self.endpoint_url)
.header("Content-Type", "application/json")
.body(request_body)
.send()
.await
.map_err(|e| FunctionCallingError::HttpError(e))?;
.map_err(FunctionCallingError::HttpError)?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
return Err(FunctionCallingError::InvalidModelResponse(
format!("HTTP error {}: {}", status, error_text)
));
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(FunctionCallingError::InvalidModelResponse(format!(
"HTTP error {}: {}",
status, error_text
)));
}
// Parse SSE stream
@ -856,38 +896,51 @@ impl ArchFunctionHandler {
}
/// Makes a non-streaming request and returns the response
async fn make_non_streaming_request(&self, request: ChatCompletionsRequest) -> Result<ChatCompletionsResponse> {
let request_body = serde_json::to_string(&request)
.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Failed to serialize request: {}", e)))?;
async fn make_non_streaming_request(
&self,
request: ChatCompletionsRequest,
) -> Result<ChatCompletionsResponse> {
let request_body = serde_json::to_string(&request).map_err(|e| {
FunctionCallingError::InvalidModelResponse(format!(
"Failed to serialize request: {}",
e
))
})?;
let response = self.http_client
let response = self
.http_client
.post(&self.endpoint_url)
.header("Content-Type", "application/json")
.body(request_body)
.send()
.await
.map_err(|e| FunctionCallingError::HttpError(e))?;
.map_err(FunctionCallingError::HttpError)?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
return Err(FunctionCallingError::InvalidModelResponse(
format!("HTTP error {}: {}", status, error_text)
));
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(FunctionCallingError::InvalidModelResponse(format!(
"HTTP error {}: {}",
status, error_text
)));
}
let response_text = response.text().await
.map_err(|e| FunctionCallingError::HttpError(e))?;
let response_text = response
.text()
.await
.map_err(FunctionCallingError::HttpError)?;
serde_json::from_str(&response_text)
.map_err(|e| FunctionCallingError::JsonParseError(e))
serde_json::from_str(&response_text).map_err(FunctionCallingError::JsonParseError)
}
pub async fn function_calling_chat(
&self,
request: ChatCompletionsRequest,
) -> Result<ChatCompletionsResponse> {
use tracing::{info, error};
use tracing::{error, info};
info!("[Arch-Function] - ChatCompletion");
@ -899,10 +952,14 @@ impl ArchFunctionHandler {
request.metadata.as_ref(),
)?;
info!("[request to arch-fc]: model: {}, messages count: {}",
self.model_name, messages.len());
info!(
"[request to arch-fc]: model: {}, messages count: {}",
self.model_name,
messages.len()
);
let use_agent_orchestrator = request.metadata
let use_agent_orchestrator = request
.metadata
.as_ref()
.and_then(|m| m.get("use_agent_orchestrator"))
.and_then(|v| v.as_bool())
@ -918,89 +975,95 @@ impl ArchFunctionHandler {
if use_agent_orchestrator {
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?;
let chunk = chunk_result.map_err(FunctionCallingError::InvalidModelResponse)?;
// Extract content from JSON response
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
if let Some(choice) = choices.first() {
if let Some(content) = choice.get("delta")
if let Some(content) = choice
.get("delta")
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str()) {
.and_then(|c| c.as_str())
{
model_response.push_str(content);
}
}
}
}
info!("[Agent Orchestrator]: response received");
} else {
if let Some(tools) = request.tools.as_ref() {
let mut hallucination_state = HallucinationState::new(tools);
let mut has_tool_calls = None;
let mut has_hallucination = false;
} else if let Some(tools) = request.tools.as_ref() {
let mut hallucination_state = HallucinationState::new(tools);
let mut has_tool_calls = None;
let mut has_hallucination = false;
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?;
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result.map_err(FunctionCallingError::InvalidModelResponse)?;
// Extract content and logprobs from JSON response
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
if let Some(choice) = choices.first() {
if let Some(content) = choice.get("delta")
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str()) {
// Extract content and logprobs from JSON response
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
if let Some(choice) = choices.first() {
if let Some(content) = choice
.get("delta")
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str())
{
// Extract logprobs
let logprobs: Vec<f64> = choice
.get("logprobs")
.and_then(|lp| lp.get("content"))
.and_then(|c| c.as_array())
.and_then(|arr| arr.first())
.and_then(|token| token.get("top_logprobs"))
.and_then(|tlp| tlp.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.get("logprob").and_then(|lp| lp.as_f64()))
.collect()
})
.unwrap_or_default();
// Extract logprobs
let logprobs: Vec<f64> = choice.get("logprobs")
.and_then(|lp| lp.get("content"))
.and_then(|c| c.as_array())
.and_then(|arr| arr.first())
.and_then(|token| token.get("top_logprobs"))
.and_then(|tlp| tlp.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.get("logprob").and_then(|lp| lp.as_f64()))
.collect()
})
.unwrap_or_default();
if hallucination_state
.append_and_check_token_hallucination(content.to_string(), logprobs)
{
has_hallucination = true;
break;
}
if hallucination_state.append_and_check_token_hallucination(content.to_string(), logprobs) {
has_hallucination = true;
break;
}
if hallucination_state.tokens.len() > 5 && has_tool_calls.is_none() {
let collected_content = hallucination_state.tokens.join("");
has_tool_calls = Some(collected_content.contains("tool_calls"));
}
if hallucination_state.tokens.len() > 5 && has_tool_calls.is_none() {
let collected_content = hallucination_state.tokens.join("");
has_tool_calls = Some(collected_content.contains("tool_calls"));
}
}
}
}
}
if has_tool_calls == Some(true) && has_hallucination {
info!("[Hallucination]: {}", hallucination_state.error_message);
if has_tool_calls == Some(true) && has_hallucination {
info!("[Hallucination]: {}", hallucination_state.error_message);
let clarify_messages = self.prefill_message(messages.clone(), &self.clarify_prefix);
let clarify_request = self.create_request_with_extra_body(clarify_messages, false);
let clarify_messages = self.prefill_message(messages.clone(), &self.clarify_prefix);
let clarify_request = self.create_request_with_extra_body(clarify_messages, false);
let retry_response = self.make_non_streaming_request(clarify_request).await?;
let retry_response = self.make_non_streaming_request(clarify_request).await?;
if let Some(choice) = retry_response.choices.first() {
if let Some(content) = &choice.message.content {
model_response = content.clone();
}
if let Some(choice) = retry_response.choices.first() {
if let Some(content) = &choice.message.content {
model_response = content.clone();
}
} else {
model_response = hallucination_state.tokens.join("");
}
} else {
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?;
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
if let Some(choice) = choices.first() {
if let Some(content) = choice.get("delta")
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str()) {
model_response.push_str(content);
}
model_response = hallucination_state.tokens.join("");
}
} else {
while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result.map_err(FunctionCallingError::InvalidModelResponse)?;
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
if let Some(choice) = choices.first() {
if let Some(content) = choice
.get("delta")
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str())
{
model_response.push_str(content);
}
}
}
@ -1009,10 +1072,17 @@ impl ArchFunctionHandler {
let response_dict = self.parse_model_response(&model_response);
info!("[arch-fc]: raw model response: {}", response_dict.raw_response);
info!(
"[arch-fc]: raw model response: {}",
response_dict.raw_response
);
// General model response (no intent matched - should route to default target)
let model_message = if response_dict.response.as_ref().map_or(false, |s| !s.is_empty()) {
let model_message = if response_dict
.response
.as_ref()
.is_some_and(|s| !s.is_empty())
{
// When arch-fc returns a "response" field, it means no intent was matched
// Return empty content and empty tool_calls so prompt_gateway routes to default target
ResponseMessage {
@ -1053,8 +1123,11 @@ impl ArchFunctionHandler {
let verification = self.verify_tool_calls(tools, &response_dict.tool_calls);
if verification.is_valid {
info!("[Tool calls]: {:?}",
response_dict.tool_calls.iter()
info!(
"[Tool calls]: {:?}",
response_dict
.tool_calls
.iter()
.map(|tc| &tc.function)
.collect::<Vec<_>>()
);
@ -1092,8 +1165,11 @@ impl ArchFunctionHandler {
}
}
} else {
info!("[Tool calls]: {:?}",
response_dict.tool_calls.iter()
info!(
"[Tool calls]: {:?}",
response_dict
.tool_calls
.iter()
.map(|tc| &tc.function)
.collect::<Vec<_>>()
);
@ -1108,7 +1184,10 @@ impl ArchFunctionHandler {
}
}
} else {
error!("Invalid tool calls in response: {}", response_dict.error_message);
error!(
"Invalid tool calls in response: {}",
response_dict.error_message
);
ResponseMessage {
role: Role::Assistant,
content: Some(String::new()),
@ -1243,7 +1322,6 @@ pub async fn function_calling_chat_handler(
req: Request<Incoming>,
llm_provider_url: String,
) -> std::result::Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
use hermesllm::apis::openai::ChatCompletionsRequest;
let whole_body = req.collect().await?.to_bytes();
@ -1255,10 +1333,13 @@ pub async fn function_calling_chat_handler(
let mut response = Response::new(full(
serde_json::json!({
"error": format!("Invalid request body: {}", e)
}).to_string()
})
.to_string(),
));
*response.status_mut() = StatusCode::BAD_REQUEST;
response.headers_mut().insert("Content-Type", "application/json".parse().unwrap());
response
.headers_mut()
.insert("Content-Type", "application/json".parse().unwrap());
return Ok(response);
}
};
@ -1271,24 +1352,31 @@ pub async fn function_calling_chat_handler(
// Parse as ChatCompletionsRequest
let chat_request: ChatCompletionsRequest = match serde_json::from_value(body_json) {
Ok(req) => {
info!("[request body]: {}", serde_json::to_string(&req).unwrap_or_default());
info!(
"[request body]: {}",
serde_json::to_string(&req).unwrap_or_default()
);
req
},
}
Err(e) => {
error!("Failed to parse request body: {}", e);
let mut response = Response::new(full(
serde_json::json!({
"error": format!("Invalid request body: {}", e)
}).to_string()
})
.to_string(),
));
*response.status_mut() = StatusCode::BAD_REQUEST;
response.headers_mut().insert("Content-Type", "application/json".parse().unwrap());
response
.headers_mut()
.insert("Content-Type", "application/json".parse().unwrap());
return Ok(response);
}
};
// Determine which handler to use based on metadata
let use_agent_orchestrator = chat_request.metadata
let use_agent_orchestrator = chat_request
.metadata
.as_ref()
.and_then(|m| m.get("use_agent_orchestrator"))
.and_then(|v| v.as_bool())
@ -1309,7 +1397,10 @@ pub async fn function_calling_chat_handler(
ARCH_FUNCTION_MODEL_NAME.to_string(),
llm_provider_url.clone(),
);
handler.function_handler.function_calling_chat(chat_request).await
handler
.function_handler
.function_calling_chat(chat_request)
.await
} else {
let handler = ArchFunctionHandler::new(
ARCH_FUNCTION_MODEL_NAME.to_string(),
@ -1328,7 +1419,9 @@ pub async fn function_calling_chat_handler(
let mut response = Response::new(full(response_json));
*response.status_mut() = StatusCode::OK;
response.headers_mut().insert("Content-Type", "application/json".parse().unwrap());
response
.headers_mut()
.insert("Content-Type", "application/json".parse().unwrap());
Ok(response)
}
@ -1341,13 +1434,14 @@ pub async fn function_calling_chat_handler(
let mut response = Response::new(full(error_response.to_string()));
*response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
response.headers_mut().insert("Content-Type", "application/json".parse().unwrap());
response
.headers_mut()
.insert("Content-Type", "application/json".parse().unwrap());
Ok(response)
}
}
}
// ============================================================================
// TESTS
// ============================================================================
@ -1370,10 +1464,13 @@ mod tests {
assert!(config.task_prompt.contains("</tools>\\n\\n"));
// Format prompt should contain literal escaped newlines and proper JSON examples
assert!(config.format_prompt.contains("\\n\\nBased on your analysis"));
assert!(config.format_prompt.contains(r#"{\"response\": \"Your response text here\"}"#));
assert!(config
.format_prompt
.contains("\\n\\nBased on your analysis"));
assert!(config
.format_prompt
.contains(r#"{\"response\": \"Your response text here\"}"#));
assert!(config.format_prompt.contains(r#"{\"tool_calls\": [{"#));
}
#[test]
@ -1384,7 +1481,11 @@ mod tests {
#[test]
fn test_fix_json_string_valid() {
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string());
let handler = ArchFunctionHandler::new(
"test-model".to_string(),
ArchFunctionConfig::default(),
"http://localhost:8000".to_string(),
);
let json_str = r#"{"name": "test", "value": 123}"#;
let result = handler.fix_json_string(json_str);
assert!(result.is_ok());
@ -1392,7 +1493,11 @@ mod tests {
#[test]
fn test_fix_json_string_missing_bracket() {
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string());
let handler = ArchFunctionHandler::new(
"test-model".to_string(),
ArchFunctionConfig::default(),
"http://localhost:8000".to_string(),
);
let json_str = r#"{"name": "test", "value": 123"#;
let result = handler.fix_json_string(json_str);
assert!(result.is_ok());
@ -1402,8 +1507,13 @@ mod tests {
#[test]
fn test_parse_model_response_with_tool_calls() {
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string());
let content = r#"{"tool_calls": [{"name": "get_weather", "arguments": {"location": "NYC"}}]}"#;
let handler = ArchFunctionHandler::new(
"test-model".to_string(),
ArchFunctionConfig::default(),
"http://localhost:8000".to_string(),
);
let content =
r#"{"tool_calls": [{"name": "get_weather", "arguments": {"location": "NYC"}}]}"#;
let result = handler.parse_model_response(content);
assert!(result.is_valid);
@ -1413,8 +1523,13 @@ mod tests {
#[test]
fn test_parse_model_response_with_clarification() {
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string());
let content = r#"{"required_functions": ["get_weather"], "clarification": "What location?"}"#;
let handler = ArchFunctionHandler::new(
"test-model".to_string(),
ArchFunctionConfig::default(),
"http://localhost:8000".to_string(),
);
let content =
r#"{"required_functions": ["get_weather"], "clarification": "What location?"}"#;
let result = handler.parse_model_response(content);
assert!(result.is_valid);
@ -1424,7 +1539,11 @@ mod tests {
#[test]
fn test_convert_data_type_int_to_float() {
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string());
let handler = ArchFunctionHandler::new(
"test-model".to_string(),
ArchFunctionConfig::default(),
"http://localhost:8000".to_string(),
);
let value = json!(42);
let result = handler.convert_data_type(&value, "float");
assert!(result.is_ok());
@ -1504,13 +1623,12 @@ pub fn check_threshold(
}
/// Checks if a parameter is required in the function description
pub fn is_parameter_required(
function_description: &Value,
parameter_name: &str,
) -> bool {
pub fn is_parameter_required(function_description: &Value, parameter_name: &str) -> bool {
if let Some(required) = function_description.get("required") {
if let Some(required_arr) = required.as_array() {
return required_arr.iter().any(|v| v.as_str() == Some(parameter_name));
return required_arr
.iter()
.any(|v| v.as_str() == Some(parameter_name));
}
}
false
@ -1559,12 +1677,7 @@ impl HallucinationState {
pub fn new(functions: &[Tool]) -> Self {
let function_properties: HashMap<String, Value> = functions
.iter()
.map(|tool| {
(
tool.function.name.clone(),
tool.function.parameters.clone(),
)
})
.map(|tool| (tool.function.name.clone(), tool.function.parameters.clone()))
.collect();
Self {
@ -1620,7 +1733,10 @@ impl HallucinationState {
// Function name extraction logic
if self.state.as_deref() == Some("function_name") {
if !FUNC_NAME_END_TOKEN.iter().any(|&t| self.tokens.last().map_or(false, |tok| tok == t)) {
if !FUNC_NAME_END_TOKEN
.iter()
.any(|&t| self.tokens.last().is_some_and(|tok| tok == t))
{
self.mask.push(MaskToken::FunctionName);
} else {
self.state = None;
@ -1629,34 +1745,51 @@ impl HallucinationState {
}
// Check for function name start
if FUNC_NAME_START_PATTERN.iter().any(|&p| content.ends_with(p)) {
if FUNC_NAME_START_PATTERN
.iter()
.any(|&p| content.ends_with(p))
{
self.state = Some("function_name".to_string());
}
// Parameter name extraction logic
if self.state.as_deref() == Some("parameter_name")
&& !PARAMETER_NAME_END_TOKENS.iter().any(|&t| content.ends_with(t)) {
&& !PARAMETER_NAME_END_TOKENS
.iter()
.any(|&t| content.ends_with(t))
{
self.mask.push(MaskToken::ParameterName);
} else if self.state.as_deref() == Some("parameter_name")
&& PARAMETER_NAME_END_TOKENS.iter().any(|&t| content.ends_with(t)) {
&& PARAMETER_NAME_END_TOKENS
.iter()
.any(|&t| content.ends_with(t))
{
self.state = None;
self.parameter_name_done = true;
self.get_parameter_name();
} else if self.parameter_name_done
&& !self.open_bracket
&& PARAMETER_NAME_START_PATTERN.iter().any(|&p| content.ends_with(p)) {
&& PARAMETER_NAME_START_PATTERN
.iter()
.any(|&p| content.ends_with(p))
{
self.state = Some("parameter_name".to_string());
}
// First parameter value start
if FIRST_PARAM_NAME_START_PATTERN.iter().any(|&p| content.ends_with(p)) {
if FIRST_PARAM_NAME_START_PATTERN
.iter()
.any(|&p| content.ends_with(p))
{
self.state = Some("parameter_name".to_string());
}
// Parameter value extraction logic
if self.state.as_deref() == Some("parameter_value")
&& !PARAMETER_VALUE_END_TOKEN.iter().any(|&t| content.ends_with(t)) {
&& !PARAMETER_VALUE_END_TOKEN
.iter()
.any(|&t| content.ends_with(t))
{
// Check for brackets
if let Some(last_token) = self.tokens.last() {
let open_brackets: Vec<char> = last_token
@ -1694,8 +1827,11 @@ impl HallucinationState {
&& self.mask[self.mask.len() - 2] != MaskToken::ParameterValue
&& !self.parameter_name.is_empty()
{
let last_param = self.parameter_name[self.parameter_name.len() - 1].clone();
if let Some(func_props) = self.function_properties.get(&self.function_name) {
let last_param =
self.parameter_name[self.parameter_name.len() - 1].clone();
if let Some(func_props) =
self.function_properties.get(&self.function_name)
{
if is_parameter_required(func_props, &last_param)
&& !is_parameter_property(func_props, &last_param, "enum")
&& !self.check_parameter_name.contains_key(&last_param)
@ -1718,10 +1854,16 @@ impl HallucinationState {
}
} else if self.state.as_deref() == Some("parameter_value")
&& !self.open_bracket
&& PARAMETER_VALUE_END_TOKEN.iter().any(|&t| content.ends_with(t)) {
&& PARAMETER_VALUE_END_TOKEN
.iter()
.any(|&t| content.ends_with(t))
{
self.state = None;
} else if self.parameter_name_done
&& PARAMETER_VALUE_START_PATTERN.iter().any(|&p| content.ends_with(p)) {
&& PARAMETER_VALUE_START_PATTERN
.iter()
.any(|&p| content.ends_with(p))
{
self.state = Some("parameter_value".to_string());
}
@ -1848,18 +1990,18 @@ mod hallucination_tests {
let handler = ArchFunctionHandler::new(
"test-model".to_string(),
ArchFunctionConfig::default(),
"http://localhost:8000".to_string()
"http://localhost:8000".to_string(),
);
// Test integer types
assert!(handler.check_value_type(&json!(42), "integer"));
assert!(handler.check_value_type(&json!(42), "int"));
assert!(!handler.check_value_type(&json!(3.14), "integer"));
assert!(!handler.check_value_type(&json!(3.15), "integer"));
// Test number types (accepts both int and float)
assert!(handler.check_value_type(&json!(3.14), "number"));
assert!(handler.check_value_type(&json!(3.15), "number"));
assert!(handler.check_value_type(&json!(42), "number"));
assert!(handler.check_value_type(&json!(3.14), "float"));
assert!(handler.check_value_type(&json!(3.15), "float"));
// Test boolean
assert!(handler.check_value_type(&json!(true), "boolean"));
@ -1890,12 +2032,16 @@ mod hallucination_tests {
let handler = ArchFunctionHandler::new(
"test-model".to_string(),
ArchFunctionConfig::default(),
"http://localhost:8000".to_string()
"http://localhost:8000".to_string(),
);
// Test valid type - no conversion needed
assert!(handler.validate_or_convert_parameter(&json!(42), "integer").unwrap());
assert!(handler.validate_or_convert_parameter(&json!("hello"), "string").unwrap());
assert!(handler
.validate_or_convert_parameter(&json!(42), "integer")
.unwrap());
assert!(handler
.validate_or_convert_parameter(&json!("hello"), "string")
.unwrap());
// Test integer to float conversion (convert_data_type supports this)
let result = handler.validate_or_convert_parameter(&json!(42), "float");
@ -1910,8 +2056,12 @@ mod hallucination_tests {
assert!(!result.unwrap());
// Test number accepting both int and float
assert!(handler.validate_or_convert_parameter(&json!(42), "number").unwrap());
assert!(handler.validate_or_convert_parameter(&json!(3.14), "number").unwrap());
assert!(handler
.validate_or_convert_parameter(&json!(42), "number")
.unwrap());
assert!(handler
.validate_or_convert_parameter(&json!(3.15), "number")
.unwrap());
}
#[test]

View file

@ -14,7 +14,7 @@ use crate::router::plano_orchestrator::OrchestratorService;
/// 2. PipelineProcessor - executes the agent pipeline
/// 3. ResponseHandler - handles response streaming
#[cfg(test)]
mod integration_tests {
mod tests {
use super::*;
use common::configuration::{Agent, AgentFilterChain, Listener};
@ -62,7 +62,10 @@ mod integration_tests {
let agent_pipeline = AgentFilterChain {
id: "terminal-agent".to_string(),
filter_chain: Some(vec!["filter-agent".to_string(), "terminal-agent".to_string()]),
filter_chain: Some(vec![
"filter-agent".to_string(),
"terminal-agent".to_string(),
]),
description: Some("Test pipeline".to_string()),
default: Some(true),
};

View file

@ -2,48 +2,48 @@ use serde::{Deserialize, Serialize};
use std::collections::HashMap;
pub const JSON_RPC_VERSION: &str = "2.0";
pub const TOOL_CALL_METHOD : &str = "tools/call";
pub const TOOL_CALL_METHOD: &str = "tools/call";
pub const MCP_INITIALIZE: &str = "initialize";
pub const MCP_INITIALIZE_NOTIFICATION: &str = "notifications/initialized";
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum JsonRpcId {
String(String),
Number(u64),
String(String),
Number(u64),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcRequest {
pub jsonrpc: String,
pub id: JsonRpcId,
pub method: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<HashMap<String, serde_json::Value>>,
pub jsonrpc: String,
pub id: JsonRpcId,
pub method: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcNotification {
pub jsonrpc: String,
pub method: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<HashMap<String, serde_json::Value>>,
pub jsonrpc: String,
pub method: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcError {
pub code: i32,
pub message: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub data: Option<serde_json::Value>,
pub code: i32,
pub message: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub data: Option<serde_json::Value>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcResponse {
pub jsonrpc: String,
pub id: JsonRpcId,
#[serde(skip_serializing_if = "Option::is_none")]
pub result: Option<HashMap<String, serde_json::Value>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<JsonRpcError>,
pub jsonrpc: String,
pub id: JsonRpcId,
#[serde(skip_serializing_if = "Option::is_none")]
pub result: Option<HashMap<String, serde_json::Value>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<JsonRpcError>,
}

View file

@ -1,6 +1,8 @@
use bytes::Bytes;
use common::configuration::{LlmProvider, ModelAlias};
use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER};
use common::consts::{
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
};
use common::traces::TraceCollector;
use hermesllm::apis::openai_responses::InputParam;
use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
@ -14,13 +16,14 @@ use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
use crate::router::llm_router::RouterService;
use crate::handlers::utils::{create_streaming_response, ObservableStreamProcessor, truncate_message};
use crate::handlers::router_chat::router_chat_get_upstream_model;
use crate::handlers::utils::{
create_streaming_response, truncate_message, ObservableStreamProcessor,
};
use crate::router::llm_router::RouterService;
use crate::state::response_state_processor::ResponsesStateProcessor;
use crate::state::{
StateStorage, StateStorageError,
extract_input_items, retrieve_and_combine_input
extract_input_items, retrieve_and_combine_input, StateStorage, StateStorageError,
};
use crate::tracing::operation_component;
@ -39,7 +42,6 @@ pub async fn llm_chat(
trace_collector: Arc<TraceCollector>,
state_storage: Option<Arc<dyn StateStorage>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let request_path = request.uri().path().to_string();
let request_headers = request.headers().clone();
let request_id = request_headers
@ -74,8 +76,14 @@ pub async fn llm_chat(
)) {
Ok(request) => request,
Err(err) => {
warn!("[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request as ProviderRequestType: {}", request_id, err);
let err_msg = format!("[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request: {}", request_id, err);
warn!(
"[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request as ProviderRequestType: {}",
request_id, err
);
let err_msg = format!(
"[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request: {}",
request_id, err
);
let mut bad_request = Response::new(full(err_msg));
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
return Ok(bad_request);
@ -85,7 +93,10 @@ pub async fn llm_chat(
// === v1/responses state management: Extract input items early ===
let mut original_input_items = Vec::new();
let client_api = SupportedAPIsFromClient::from_endpoint(request_path.as_str());
let is_responses_api_client = matches!(client_api, Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_)));
let is_responses_api_client = matches!(
client_api,
Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_))
);
// Model alias resolution: update model field in client_request immediately
// This ensures all downstream objects use the resolved model
@ -96,20 +107,28 @@ pub async fn llm_chat(
// Extract tool names and user message preview for span attributes
let tool_names = client_request.get_tool_names();
let user_message_preview = client_request.get_recent_user_message()
let user_message_preview = client_request
.get_recent_user_message()
.map(|msg| truncate_message(&msg, 50));
client_request.set_model(resolved_model.clone());
if client_request.remove_metadata_key("archgw_preference_config") {
debug!("[PLANO_REQ_ID:{}] Removed archgw_preference_config from metadata", request_id);
debug!(
"[PLANO_REQ_ID:{}] Removed archgw_preference_config from metadata",
request_id
);
}
// === v1/responses state management: Determine upstream API and combine input if needed ===
// Do this BEFORE routing since routing consumes the request
// Only process state if state_storage is configured
let mut should_manage_state = false;
if is_responses_api_client && state_storage.is_some() {
if let ProviderRequestType::ResponsesAPIRequest(ref mut responses_req) = client_request {
if is_responses_api_client {
if let (
ProviderRequestType::ResponsesAPIRequest(ref mut responses_req),
Some(ref state_store),
) = (&mut client_request, &state_storage)
{
// Extract original input once
original_input_items = extract_input_items(&responses_req.input);
@ -120,18 +139,22 @@ pub async fn llm_chat(
&request_path,
&resolved_model,
is_streaming_request,
).await;
)
.await;
let upstream_api = SupportedUpstreamAPIs::from_endpoint(&upstream_path);
// Only manage state if upstream is NOT OpenAIResponsesAPI (needs translation)
should_manage_state = !matches!(upstream_api, Some(SupportedUpstreamAPIs::OpenAIResponsesAPI(_)));
should_manage_state = !matches!(
upstream_api,
Some(SupportedUpstreamAPIs::OpenAIResponsesAPI(_))
);
if should_manage_state {
// Retrieve and combine conversation history if previous_response_id exists
if let Some(ref prev_resp_id) = responses_req.previous_response_id {
match retrieve_and_combine_input(
state_storage.as_ref().unwrap().clone(),
state_store.clone(),
prev_resp_id,
original_input_items, // Pass ownership instead of cloning
)
@ -166,7 +189,10 @@ pub async fn llm_chat(
}
}
} else {
debug!("[PLANO_REQ_ID:{}] | BRIGHT_STAFF | Upstream supports ResponsesAPI natively.", request_id);
debug!(
"[PLANO_REQ_ID:{}] | BRIGHT_STAFF | Upstream supports ResponsesAPI natively.",
request_id
);
}
}
}
@ -177,7 +203,7 @@ pub async fn llm_chat(
// Determine routing using the dedicated router_chat module
let routing_result = match router_chat_get_upstream_model(
router_service,
client_request, // Pass the original request - router_chat will convert it
client_request, // Pass the original request - router_chat will convert it
&request_headers,
trace_collector.clone(),
&traceparent,
@ -257,7 +283,8 @@ pub async fn llm_chat(
user_message_preview,
temperature,
&llm_providers,
).await;
)
.await;
// Create base processor for metrics and tracing
let base_processor = ObservableStreamProcessor::new(
@ -269,7 +296,11 @@ pub async fn llm_chat(
// === v1/responses state management: Wrap with ResponsesStateProcessor ===
// Only wrap if we need to manage state (client is ResponsesAPI AND upstream is NOT ResponsesAPI AND state_storage is configured)
let streaming_response = if should_manage_state && !original_input_items.is_empty() && state_storage.is_some() {
let streaming_response = if let (true, false, Some(state_store)) = (
should_manage_state,
original_input_items.is_empty(),
state_storage,
) {
// Extract Content-Encoding header to handle decompression for state parsing
let content_encoding = response_headers
.get("content-encoding")
@ -279,7 +310,7 @@ pub async fn llm_chat(
// Wrap with state management processor to store state after response completes
let state_processor = ResponsesStateProcessor::new(
base_processor,
state_storage.unwrap(),
state_store,
original_input_items,
resolved_model.clone(),
model_name.clone(),
@ -324,6 +355,7 @@ fn resolve_model_alias(
}
/// Builds the LLM span with all required and optional attributes.
#[allow(clippy::too_many_arguments)]
async fn build_llm_span(
traceparent: &str,
request_path: &str,
@ -337,8 +369,8 @@ async fn build_llm_span(
temperature: Option<f32>,
llm_providers: &Arc<RwLock<Vec<LlmProvider>>>,
) -> common::traces::Span {
use common::traces::{SpanBuilder, SpanKind, parse_traceparent};
use crate::tracing::{http, llm, OperationNameBuilder};
use common::traces::{parse_traceparent, SpanBuilder, SpanKind};
// Calculate the upstream path based on provider configuration
let upstream_path = get_upstream_path(
@ -347,13 +379,14 @@ async fn build_llm_span(
request_path,
resolved_model,
is_streaming,
).await;
)
.await;
// Build operation name showing path transformation if different
let operation_name = if request_path != upstream_path {
OperationNameBuilder::new()
.with_method("POST")
.with_path(&format!("{} >> {}", request_path, upstream_path))
.with_path(format!("{} >> {}", request_path, upstream_path))
.with_target(resolved_model)
.build()
} else {
@ -388,7 +421,8 @@ async fn build_llm_span(
}
if let Some(tools) = tool_names {
let formatted_tools = tools.iter()
let formatted_tools = tools
.iter()
.map(|name| format!("{}(...)", name))
.collect::<Vec<_>>()
.join("\n");
@ -436,8 +470,7 @@ async fn get_provider_info(
// First, try to find by model name or provider name
let provider = providers_lock.iter().find(|p| {
p.model.as_ref().map(|m| m == model_name).unwrap_or(false)
|| p.name == model_name
p.model.as_ref().map(|m| m == model_name).unwrap_or(false) || p.name == model_name
});
if let Some(provider) = provider {
@ -446,9 +479,7 @@ async fn get_provider_info(
return (provider_id, prefix);
}
let default_provider = providers_lock.iter().find(|p| {
p.default.unwrap_or(false)
});
let default_provider = providers_lock.iter().find(|p| p.default.unwrap_or(false));
if let Some(provider) = default_provider {
let provider_id = provider.provider_interface.to_provider_id();

View file

@ -1,13 +1,13 @@
pub mod agent_chat_completions;
pub mod agent_selector;
pub mod llm;
pub mod router_chat;
pub mod models;
pub mod function_calling;
pub mod jsonrpc;
pub mod llm;
pub mod models;
pub mod pipeline_processor;
pub mod response_handler;
pub mod router_chat;
pub mod utils;
pub mod jsonrpc;
#[cfg(test)]
mod integration_tests;

View file

@ -82,6 +82,7 @@ impl PipelineProcessor {
}
/// Record a span for filter execution
#[allow(clippy::too_many_arguments)]
fn record_filter_span(
&self,
collector: &std::sync::Arc<common::traces::TraceCollector>,
@ -132,6 +133,7 @@ impl PipelineProcessor {
}
/// Record a span for MCP protocol interactions
#[allow(clippy::too_many_arguments)]
fn record_agent_filter_span(
&self,
collector: &std::sync::Arc<common::traces::TraceCollector>,
@ -156,12 +158,12 @@ impl PipelineProcessor {
.build();
let mut span_builder = SpanBuilder::new(&operation_name)
.with_span_id(span_id.unwrap_or_else(|| generate_random_span_id()))
.with_span_id(span_id.unwrap_or_else(generate_random_span_id))
.with_kind(SpanKind::Client)
.with_start_time(start_time)
.with_end_time(end_time)
.with_attribute(http::METHOD, "POST")
.with_attribute(http::TARGET, &format!("/mcp ({})", operation.to_string()))
.with_attribute(http::TARGET, format!("/mcp ({})", operation))
.with_attribute("mcp.operation", operation.to_string())
.with_attribute("mcp.agent_id", agent_id.to_string())
.with_attribute(
@ -188,6 +190,7 @@ impl PipelineProcessor {
}
/// Process the filter chain of agents (all except the terminal agent)
#[allow(clippy::too_many_arguments)]
pub async fn process_filter_chain(
&mut self,
chat_history: &[Message],
@ -1023,7 +1026,7 @@ mod tests {
}
});
let sse_body = format!("event: message\ndata: {}\n\n", rpc_body.to_string());
let sse_body = format!("event: message\ndata: {}\n\n", rpc_body);
let mut server = Server::new_async().await;
let _m = server
@ -1061,10 +1064,10 @@ mod tests {
.await;
match result {
Err(PipelineError::ClientError { status, body, .. }) => {
assert_eq!(status, 400);
assert_eq!(body, "bad tool call");
}
Err(PipelineError::ClientError { status, body, .. }) => {
assert_eq!(status, 400);
assert_eq!(body, "bad tool call");
}
_ => panic!("Expected client error when isError flag is set"),
}
}

View file

@ -133,9 +133,7 @@ impl ResponseHandler {
let response_headers = llm_response.headers();
let is_sse_streaming = response_headers
.get(hyper::header::CONTENT_TYPE)
.map_or(false, |v| {
v.to_str().unwrap_or("").contains("text/event-stream")
});
.is_some_and(|v| v.to_str().unwrap_or("").contains("text/event-stream"));
let response_bytes = llm_response
.bytes()
@ -164,7 +162,7 @@ impl ResponseHandler {
match transformed_event.provider_response() {
Ok(provider_response) => {
if let Some(content) = provider_response.content_delta() {
accumulated_text.push_str(&content);
accumulated_text.push_str(content);
} else {
info!("No content delta in provider response");
}
@ -174,7 +172,7 @@ impl ResponseHandler {
}
}
}
return Ok(accumulated_text);
Ok(accumulated_text)
} else {
// If not SSE, treat as regular text response
let response_text = String::from_utf8(response_bytes.to_vec()).map_err(|e| {

View file

@ -1,6 +1,6 @@
use common::configuration::ModelUsagePreference;
use common::consts::{REQUEST_ID_HEADER};
use common::traces::{TraceCollector, SpanKind, SpanBuilder, parse_traceparent};
use common::consts::REQUEST_ID_HEADER;
use common::traces::{parse_traceparent, SpanBuilder, SpanKind, TraceCollector};
use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
use hermesllm::{ProviderRequest, ProviderRequestType};
use hyper::StatusCode;
@ -9,10 +9,10 @@ use std::sync::Arc;
use tracing::{debug, info, warn};
use crate::router::llm_router::RouterService;
use crate::tracing::{OperationNameBuilder, operation_component, http, routing};
use crate::tracing::{http, operation_component, routing, OperationNameBuilder};
pub struct RoutingResult {
pub model_name: String
pub model_name: String,
}
pub struct RoutingError {
@ -24,7 +24,7 @@ impl RoutingError {
pub fn internal_error(message: String) -> Self {
Self {
message,
status_code: StatusCode::INTERNAL_SERVER_ERROR
status_code: StatusCode::INTERNAL_SERVER_ERROR,
}
}
}
@ -52,9 +52,7 @@ pub async fn router_chat_get_upstream_model(
// Convert to ChatCompletionsRequest for routing (regardless of input type)
let chat_request = match ProviderRequestType::try_from((
client_request,
&SupportedUpstreamAPIs::OpenAIChatCompletions(
hermesllm::apis::OpenAIApi::ChatCompletions,
),
&SupportedUpstreamAPIs::OpenAIChatCompletions(hermesllm::apis::OpenAIApi::ChatCompletions),
)) {
Ok(ProviderRequestType::ChatCompletionsRequest(req)) => req,
Ok(
@ -69,7 +67,10 @@ pub async fn router_chat_get_upstream_model(
));
}
Err(err) => {
warn!("Failed to convert request to ChatCompletionsRequest: {}", err);
warn!(
"Failed to convert request to ChatCompletionsRequest: {}",
err
);
return Err(RoutingError::internal_error(format!(
"Failed to convert request: {}",
err
@ -151,9 +152,7 @@ pub async fn router_chat_get_upstream_model(
)
.await;
Ok(RoutingResult {
model_name
})
Ok(RoutingResult { model_name })
}
None => {
// No route determined, use default model from request
@ -176,7 +175,7 @@ pub async fn router_chat_get_upstream_model(
.await;
Ok(RoutingResult {
model_name: default_model
model_name: default_model,
})
}
},
@ -194,9 +193,10 @@ pub async fn router_chat_get_upstream_model(
)
.await;
Err(RoutingError::internal_error(
format!("Failed to determine route: {}", err)
))
Err(RoutingError::internal_error(format!(
"Failed to determine route: {}",
err
)))
}
}
}
@ -230,7 +230,10 @@ async fn record_routing_span(
.with_end_time(std::time::SystemTime::now())
.with_attribute(http::METHOD, "POST")
.with_attribute(http::TARGET, routing_api_path.to_string())
.with_attribute(routing::ROUTE_DETERMINATION_MS, start_time.elapsed().as_millis().to_string());
.with_attribute(
routing::ROUTE_DETERMINATION_MS,
start_time.elapsed().as_millis().to_string(),
);
// Only set parent span ID if it exists (not a root span)
if let Some(parent) = parent_span_id {

View file

@ -1,5 +1,5 @@
use bytes::Bytes;
use common::traces::{Span, Attribute, AttributeValue, TraceCollector, Event};
use common::traces::{Attribute, AttributeValue, Event, Span, TraceCollector};
use http_body_util::combinators::BoxBody;
use http_body_util::StreamBody;
use hyper::body::Frame;
@ -11,7 +11,7 @@ use tokio_stream::StreamExt;
use tracing::warn;
// Import tracing constants
use crate::tracing::{llm, error};
use crate::tracing::{error, llm};
/// Trait for processing streaming chunks
/// Implementors can inject custom logic during streaming (e.g., hallucination detection, logging)
@ -97,7 +97,6 @@ impl StreamProcessor for ObservableStreamProcessor {
},
});
self.span.attributes.push(Attribute {
key: llm::DURATION_MS.to_string(),
value: AttributeValue {
@ -119,11 +118,9 @@ impl StreamProcessor for ObservableStreamProcessor {
if let Ok(start_time_nanos) = self.span.start_time_unix_nano.parse::<u128>() {
// Convert ttft from milliseconds to nanoseconds and add to start time
let event_timestamp = start_time_nanos + (ttft * 1_000_000);
let mut event = Event::new(llm::TIME_TO_FIRST_TOKEN_MS.to_string(), event_timestamp);
event.add_attribute(
llm::TIME_TO_FIRST_TOKEN_MS.to_string(),
ttft.to_string(),
);
let mut event =
Event::new(llm::TIME_TO_FIRST_TOKEN_MS.to_string(), event_timestamp);
event.add_attribute(llm::TIME_TO_FIRST_TOKEN_MS.to_string(), ttft.to_string());
// Initialize events vector if needed
if self.span.events.is_none() {
@ -137,7 +134,8 @@ impl StreamProcessor for ObservableStreamProcessor {
}
// Record the finalized span
self.collector.record_span(&self.service_name, self.span.clone());
self.collector
.record_span(&self.service_name, self.span.clone());
}
fn on_error(&mut self, error_msg: &str) {
@ -173,7 +171,8 @@ impl StreamProcessor for ObservableStreamProcessor {
});
// Record the error span
self.collector.record_span(&self.service_name, self.span.clone());
self.collector
.record_span(&self.service_name, self.span.clone());
}
}

View file

@ -4,13 +4,15 @@ use brightstaff::handlers::llm::llm_chat;
use brightstaff::handlers::models::list_models;
use brightstaff::router::llm_router::RouterService;
use brightstaff::router::plano_orchestrator::OrchestratorService;
use brightstaff::state::StateStorage;
use brightstaff::state::postgresql::PostgreSQLConversationStorage;
use brightstaff::state::memory::MemoryConversationalStorage;
use brightstaff::state::postgresql::PostgreSQLConversationStorage;
use brightstaff::state::StateStorage;
use brightstaff::utils::tracing::init_tracer;
use bytes::Bytes;
use common::configuration::{Agent, Configuration};
use common::consts::{CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH, PLANO_ORCHESTRATOR_MODEL_NAME};
use common::consts::{
CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH, PLANO_ORCHESTRATOR_MODEL_NAME,
};
use common::traces::TraceCollector;
use http_body_util::{combinators::BoxBody, BodyExt, Empty};
use hyper::body::Incoming;
@ -105,7 +107,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
PLANO_ORCHESTRATOR_MODEL_NAME.to_string(),
));
let model_aliases = Arc::new(arch_config.model_aliases.clone());
// Initialize trace collector and start background flusher
@ -127,33 +128,33 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
// Configurable via arch_config.yaml state_storage section
// If not configured, state management is disabled
// Environment variables are substituted by envsubst before config is read
let state_storage: Option<Arc<dyn StateStorage>> = if let Some(storage_config) = &arch_config.state_storage {
let storage: Arc<dyn StateStorage> = match storage_config.storage_type {
common::configuration::StateStorageType::Memory => {
info!("Initialized conversation state storage: Memory");
Arc::new(MemoryConversationalStorage::new())
}
common::configuration::StateStorageType::Postgres => {
let connection_string = storage_config
.connection_string
.as_ref()
.expect("connection_string is required for postgres state_storage");
let state_storage: Option<Arc<dyn StateStorage>> =
if let Some(storage_config) = &arch_config.state_storage {
let storage: Arc<dyn StateStorage> = match storage_config.storage_type {
common::configuration::StateStorageType::Memory => {
info!("Initialized conversation state storage: Memory");
Arc::new(MemoryConversationalStorage::new())
}
common::configuration::StateStorageType::Postgres => {
let connection_string = storage_config
.connection_string
.as_ref()
.expect("connection_string is required for postgres state_storage");
debug!("Postgres connection string (full): {}", connection_string);
info!("Initializing conversation state storage: Postgres");
Arc::new(
PostgreSQLConversationStorage::new(connection_string.clone())
.await
.expect("Failed to initialize Postgres state storage"),
)
}
debug!("Postgres connection string (full): {}", connection_string);
info!("Initializing conversation state storage: Postgres");
Arc::new(
PostgreSQLConversationStorage::new(connection_string.clone())
.await
.expect("Failed to initialize Postgres state storage"),
)
}
};
Some(storage)
} else {
info!("No state_storage configured - conversation state management disabled");
None
};
Some(storage)
} else {
info!("No state_storage configured - conversation state management disabled");
None
};
loop {
let (stream, _) = listener.accept().await?;
@ -208,12 +209,22 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
}
}
match (req.method(), path) {
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH) => {
let fully_qualified_url =
format!("{}{}", llm_provider_url, path);
llm_chat(req, router_service, fully_qualified_url, model_aliases, llm_providers, trace_collector, state_storage)
.with_context(parent_cx)
.await
(
&Method::POST,
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH,
) => {
let fully_qualified_url = format!("{}{}", llm_provider_url, path);
llm_chat(
req,
router_service,
fully_qualified_url,
model_aliases,
llm_providers,
trace_collector,
state_storage,
)
.with_context(parent_cx)
.await
}
(&Method::POST, "/function_calling") => {
let fully_qualified_url =

View file

@ -2,7 +2,7 @@ use std::collections::HashMap;
use common::configuration::{AgentUsagePreference, OrchestrationPreference};
use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role};
use serde::{Deserialize, Serialize, ser::Serialize as SerializeTrait};
use serde::{ser::Serialize as SerializeTrait, Deserialize, Serialize};
use tracing::{debug, warn};
use super::orchestrator_model::{OrchestratorModel, OrchestratorModelError};
@ -144,7 +144,7 @@ impl OrchestratorModelV1 {
// Format routes: each route as JSON on its own line with standard spacing
let agent_orchestration_json_str = agent_orchestration_values
.iter()
.map(|pref| to_spaced_json(pref))
.map(to_spaced_json)
.collect::<Vec<String>>()
.join("\n");
let agent_orchestration_to_model_map: HashMap<String, String> = agent_orchestrations
@ -238,24 +238,26 @@ impl OrchestratorModel for OrchestratorModelV1 {
let selected_conversation_list = selected_messages_list_reversed
.iter()
.rev()
.map(|message| {
Message {
role: message.role.clone(),
content: MessageContent::Text(message.content.to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
}
.map(|message| Message {
role: message.role.clone(),
content: MessageContent::Text(message.content.to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
})
.collect::<Vec<Message>>();
// Generate the orchestrator request message based on the usage preferences.
// If preferences are passed in request then we use them;
// Otherwise, we use the default orchestration modelpreferences.
let orchestrator_message = match convert_to_orchestrator_preferences(usage_preferences_from_request) {
Some(prefs) => generate_orchestrator_message(&prefs, &selected_conversation_list),
None => generate_orchestrator_message(&self.agent_orchestration_json_str, &selected_conversation_list),
};
let orchestrator_message =
match convert_to_orchestrator_preferences(usage_preferences_from_request) {
Some(prefs) => generate_orchestrator_message(&prefs, &selected_conversation_list),
None => generate_orchestrator_message(
&self.agent_orchestration_json_str,
&selected_conversation_list,
),
};
ChatCompletionsRequest {
model: self.orchestration_model.clone(),
@ -280,7 +282,8 @@ impl OrchestratorModel for OrchestratorModelV1 {
return Ok(None);
}
let orchestrator_resp_fixed = fix_json_response(content);
let orchestrator_response: AgentOrchestratorResponse = serde_json::from_str(orchestrator_resp_fixed.as_str())?;
let orchestrator_response: AgentOrchestratorResponse =
serde_json::from_str(orchestrator_resp_fixed.as_str())?;
let selected_routes = orchestrator_response.route.unwrap_or_default();
@ -320,7 +323,11 @@ impl OrchestratorModel for OrchestratorModelV1 {
} else {
// If no usage preferences are passed in request then use the default orchestration model preferences
for selected_route in valid_routes {
if let Some(model) = self.agent_orchestration_to_model_map.get(&selected_route).cloned() {
if let Some(model) = self
.agent_orchestration_to_model_map
.get(&selected_route)
.cloned()
{
result.push((selected_route, model));
} else {
warn!(
@ -375,7 +382,7 @@ fn convert_to_orchestrator_preferences(
// Format routes: each route as JSON on its own line with standard spacing
let routes_str = orchestration_preferences
.iter()
.map(|pref| to_spaced_json(pref))
.map(to_spaced_json)
.collect::<Vec<String>>()
.join("\n");
@ -425,7 +432,10 @@ mod tests {
// CRITICAL: Test that colons inside string values are NOT modified
let with_colon = serde_json::json!({"name": "foo:bar", "url": "http://example.com"});
let result = to_spaced_json(&with_colon);
assert_eq!(result, r#"{"name": "foo:bar", "url": "http://example.com"}"#);
assert_eq!(
result,
r#"{"name": "foo:bar", "url": "http://example.com"}"#
);
// Test empty object and array
let empty_obj = serde_json::json!({});
@ -446,7 +456,8 @@ mod tests {
});
let result = to_spaced_json(&complex);
// Verify URLs with colons are preserved correctly
assert!(result.contains(r#""urls": ["https://api.example.com:8080/path", "file:///local/path"]"#));
assert!(result
.contains(r#""urls": ["https://api.example.com:8080/path", "file:///local/path"]"#));
// Verify spacing format
assert!(result.contains(r#""type": "object""#));
assert!(result.contains(r#""properties": {}"#));
@ -497,10 +508,16 @@ If no routes are needed, return an empty list for `route`.
]
}
"#;
let agent_orchestrations =
serde_json::from_str::<HashMap<String, Vec<OrchestrationPreference>>>(orchestrations_str).unwrap();
let agent_orchestrations = serde_json::from_str::<
HashMap<String, Vec<OrchestrationPreference>>,
>(orchestrations_str)
.unwrap();
let orchestration_model = "test-model".to_string();
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), usize::MAX);
let orchestrator = OrchestratorModelV1::new(
agent_orchestrations,
orchestration_model.clone(),
usize::MAX,
);
let conversation_str = r#"
[
@ -568,7 +585,11 @@ If no routes are needed, return an empty list for `route`.
// Empty orchestrations map - not used when usage_preferences are provided
let agent_orchestrations: HashMap<String, Vec<OrchestrationPreference>> = HashMap::new();
let orchestration_model = "test-model".to_string();
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), usize::MAX);
let orchestrator = OrchestratorModelV1::new(
agent_orchestrations,
orchestration_model.clone(),
usize::MAX,
);
let conversation_str = r#"
[
@ -640,10 +661,13 @@ If no routes are needed, return an empty list for `route`.
]
}
"#;
let agent_orchestrations =
serde_json::from_str::<HashMap<String, Vec<OrchestrationPreference>>>(orchestrations_str).unwrap();
let agent_orchestrations = serde_json::from_str::<
HashMap<String, Vec<OrchestrationPreference>>,
>(orchestrations_str)
.unwrap();
let orchestration_model = "test-model".to_string();
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), 235);
let orchestrator =
OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), 235);
let conversation_str = r#"
[
@ -709,11 +733,14 @@ If no routes are needed, return an empty list for `route`.
]
}
"#;
let agent_orchestrations =
serde_json::from_str::<HashMap<String, Vec<OrchestrationPreference>>>(orchestrations_str).unwrap();
let agent_orchestrations = serde_json::from_str::<
HashMap<String, Vec<OrchestrationPreference>>,
>(orchestrations_str)
.unwrap();
let orchestration_model = "test-model".to_string();
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), 200);
let orchestrator =
OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), 200);
let conversation_str = r#"
[
@ -787,10 +814,13 @@ If no routes are needed, return an empty list for `route`.
]
}
"#;
let agent_orchestrations =
serde_json::from_str::<HashMap<String, Vec<OrchestrationPreference>>>(orchestrations_str).unwrap();
let agent_orchestrations = serde_json::from_str::<
HashMap<String, Vec<OrchestrationPreference>>,
>(orchestrations_str)
.unwrap();
let orchestration_model = "test-model".to_string();
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), 230);
let orchestrator =
OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), 230);
let conversation_str = r#"
[
@ -871,10 +901,16 @@ If no routes are needed, return an empty list for `route`.
]
}
"#;
let agent_orchestrations =
serde_json::from_str::<HashMap<String, Vec<OrchestrationPreference>>>(orchestrations_str).unwrap();
let agent_orchestrations = serde_json::from_str::<
HashMap<String, Vec<OrchestrationPreference>>,
>(orchestrations_str)
.unwrap();
let orchestration_model = "test-model".to_string();
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), usize::MAX);
let orchestrator = OrchestratorModelV1::new(
agent_orchestrations,
orchestration_model.clone(),
usize::MAX,
);
let conversation_str = r#"
[
@ -957,10 +993,16 @@ If no routes are needed, return an empty list for `route`.
]
}
"#;
let agent_orchestrations =
serde_json::from_str::<HashMap<String, Vec<OrchestrationPreference>>>(orchestrations_str).unwrap();
let agent_orchestrations = serde_json::from_str::<
HashMap<String, Vec<OrchestrationPreference>>,
>(orchestrations_str)
.unwrap();
let orchestration_model = "test-model".to_string();
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), usize::MAX);
let orchestrator = OrchestratorModelV1::new(
agent_orchestrations,
orchestration_model.clone(),
usize::MAX,
);
let conversation_str = r#"
[
@ -1034,10 +1076,13 @@ If no routes are needed, return an empty list for `route`.
]
}
"#;
let agent_orchestrations =
serde_json::from_str::<HashMap<String, Vec<OrchestrationPreference>>>(orchestrations_str).unwrap();
let agent_orchestrations = serde_json::from_str::<
HashMap<String, Vec<OrchestrationPreference>>,
>(orchestrations_str)
.unwrap();
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, "test-model".to_string(), 2000);
let orchestrator =
OrchestratorModelV1::new(agent_orchestrations, "test-model".to_string(), 2000);
// Case 1: Valid JSON with single route in array
let input = r#"{"route": ["Image generation"]}"#;

View file

@ -34,10 +34,7 @@ pub enum OrchestrationError {
pub type Result<T> = std::result::Result<T, OrchestrationError>;
impl OrchestratorService {
pub fn new(
orchestrator_url: String,
orchestration_model_name: String,
) -> Self {
pub fn new(orchestrator_url: String, orchestration_model_name: String) -> Self {
// Empty agent orchestrations - will be provided via usage_preferences in requests
let agent_orchestrations: HashMap<String, Vec<OrchestrationPreference>> = HashMap::new();

View file

@ -85,13 +85,19 @@ impl StateStorage for MemoryConversationalStorage {
#[cfg(test)]
mod tests {
use super::*;
use hermesllm::apis::openai_responses::{InputItem, InputMessage, MessageRole, InputContent, MessageContent};
use hermesllm::apis::openai_responses::{
InputContent, InputItem, InputMessage, MessageContent, MessageRole,
};
fn create_test_state(response_id: &str, num_messages: usize) -> OpenAIConversationState {
let mut input_items = Vec::new();
for i in 0..num_messages {
input_items.push(InputItem::Message(InputMessage {
role: if i % 2 == 0 { MessageRole::User } else { MessageRole::Assistant },
role: if i % 2 == 0 {
MessageRole::User
} else {
MessageRole::Assistant
},
content: MessageContent::Items(vec![InputContent::InputText {
text: format!("Message {}", i),
}]),
@ -252,7 +258,9 @@ mod tests {
let merged = storage.merge(&prev_state, current_input);
// Verify order: prev messages first, then current
let InputItem::Message(msg) = &merged[0] else { panic!("Expected Message") };
let InputItem::Message(msg) = &merged[0] else {
panic!("Expected Message")
};
match &msg.content {
MessageContent::Items(items) => match &items[0] {
InputContent::InputText { text } => assert_eq!(text, "Message 0"),
@ -261,7 +269,9 @@ mod tests {
_ => panic!("Expected MessageContent::Items"),
}
let InputItem::Message(msg) = &merged[2] else { panic!("Expected Message") };
let InputItem::Message(msg) = &merged[2] else {
panic!("Expected Message")
};
match &msg.content {
MessageContent::Items(items) => match &items[0] {
InputContent::InputText { text } => assert_eq!(text, "Message 2"),
@ -404,7 +414,8 @@ mod tests {
let current_input = vec![InputItem::Message(InputMessage {
role: MessageRole::User,
content: MessageContent::Items(vec![InputContent::InputText {
text: "Function result: {\"temperature\": 72, \"condition\": \"sunny\"}".to_string(),
text: "Function result: {\"temperature\": 72, \"condition\": \"sunny\"}"
.to_string(),
}]),
})];
@ -415,7 +426,9 @@ mod tests {
assert_eq!(merged.len(), 3);
// Verify the order and content
let InputItem::Message(msg1) = &merged[0] else { panic!("Expected Message") };
let InputItem::Message(msg1) = &merged[0] else {
panic!("Expected Message")
};
assert!(matches!(msg1.role, MessageRole::User));
match &msg1.content {
MessageContent::Items(items) => match &items[0] {
@ -427,7 +440,9 @@ mod tests {
_ => panic!("Expected MessageContent::Items"),
}
let InputItem::Message(msg2) = &merged[1] else { panic!("Expected Message") };
let InputItem::Message(msg2) = &merged[1] else {
panic!("Expected Message")
};
assert!(matches!(msg2.role, MessageRole::Assistant));
match &msg2.content {
MessageContent::Items(items) => match &items[0] {
@ -439,7 +454,9 @@ mod tests {
_ => panic!("Expected MessageContent::Items"),
}
let InputItem::Message(msg3) = &merged[2] else { panic!("Expected Message") };
let InputItem::Message(msg3) = &merged[2] else {
panic!("Expected Message")
};
assert!(matches!(msg3.role, MessageRole::User));
match &msg3.content {
MessageContent::Items(items) => match &items[0] {
@ -508,11 +525,15 @@ mod tests {
assert_eq!(merged.len(), 5);
// Verify first item is original user message
let InputItem::Message(first) = &merged[0] else { panic!("Expected Message") };
let InputItem::Message(first) = &merged[0] else {
panic!("Expected Message")
};
assert!(matches!(first.role, MessageRole::User));
// Verify last two are function outputs
let InputItem::Message(second_last) = &merged[3] else { panic!("Expected Message") };
let InputItem::Message(second_last) = &merged[3] else {
panic!("Expected Message")
};
assert!(matches!(second_last.role, MessageRole::User));
match &second_last.content {
MessageContent::Items(items) => match &items[0] {
@ -522,7 +543,9 @@ mod tests {
_ => panic!("Expected MessageContent::Items"),
}
let InputItem::Message(last) = &merged[4] else { panic!("Expected Message") };
let InputItem::Message(last) = &merged[4] else {
panic!("Expected Message")
};
assert!(matches!(last.role, MessageRole::User));
match &last.content {
MessageContent::Items(items) => match &items[0] {
@ -590,7 +613,9 @@ mod tests {
assert_eq!(merged.len(), 5);
// Verify the entire conversation flow is preserved
let InputItem::Message(first) = &merged[0] else { panic!("Expected Message") };
let InputItem::Message(first) = &merged[0] else {
panic!("Expected Message")
};
match &first.content {
MessageContent::Items(items) => match &items[0] {
InputContent::InputText { text } => assert!(text.contains("What's the weather")),
@ -599,7 +624,9 @@ mod tests {
_ => panic!("Expected MessageContent::Items"),
}
let InputItem::Message(last) = &merged[4] else { panic!("Expected Message") };
let InputItem::Message(last) = &merged[4] else {
panic!("Expected Message")
};
match &last.content {
MessageContent::Items(items) => match &items[0] {
InputContent::InputText { text } => assert!(text.contains("umbrella")),

View file

@ -1,14 +1,16 @@
use async_trait::async_trait;
use hermesllm::apis::openai_responses::{InputItem, InputMessage, InputContent, MessageContent, MessageRole, InputParam};
use hermesllm::apis::openai_responses::{
InputContent, InputItem, InputMessage, InputParam, MessageContent, MessageRole,
};
use serde::{Deserialize, Serialize};
use std::error::Error;
use std::fmt;
use std::sync::Arc;
use tracing::{debug};
use tracing::debug;
pub mod memory;
pub mod response_state_processor;
pub mod postgresql;
pub mod response_state_processor;
/// Represents the conversational state for a v1/responses request
/// Contains the complete input/output history that can be restored
@ -47,7 +49,9 @@ pub enum StateStorageError {
impl fmt::Display for StateStorageError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
StateStorageError::NotFound(id) => write!(f, "Conversation state not found for response_id: {}", id),
StateStorageError::NotFound(id) => {
write!(f, "Conversation state not found for response_id: {}", id)
}
StateStorageError::StorageError(msg) => write!(f, "Storage error: {}", msg),
StateStorageError::SerializationError(msg) => write!(f, "Serialization error: {}", msg),
}
@ -96,8 +100,6 @@ pub trait StateStorage: Send + Sync {
}
}
/// Storage backend type enum
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StorageBackend {
@ -106,7 +108,7 @@ pub enum StorageBackend {
}
impl StorageBackend {
pub fn from_str(s: &str) -> Option<Self> {
pub fn parse_backend(s: &str) -> Option<Self> {
match s.to_lowercase().as_str() {
"memory" => Some(StorageBackend::Memory),
"supabase" => Some(StorageBackend::Supabase),
@ -139,7 +141,6 @@ pub async fn retrieve_and_combine_input(
previous_response_id: &str,
current_input: Vec<InputItem>,
) -> Result<Vec<InputItem>, StateStorageError> {
// First get the previous state
let prev_state = storage.get(previous_response_id).await?;
let combined_input = storage.merge(&prev_state, current_input);

View file

@ -149,13 +149,12 @@ impl StateStorage for PostgreSQLConversationStorage {
let provider: String = row.get("provider");
// Deserialize input_items from JSONB
let input_items =
serde_json::from_value(input_items_json).map_err(|e| {
StateStorageError::StorageError(format!(
"Failed to deserialize input_items: {}",
e
))
})?;
let input_items = serde_json::from_value(input_items_json).map_err(|e| {
StateStorageError::StorageError(format!(
"Failed to deserialize input_items: {}",
e
))
})?;
Ok(OpenAIConversationState {
response_id,
@ -230,7 +229,9 @@ Run that SQL file against your database before using this storage backend.
#[cfg(test)]
mod tests {
use super::*;
use hermesllm::apis::openai_responses::{InputContent, InputItem, InputMessage, MessageContent, MessageRole};
use hermesllm::apis::openai_responses::{
InputContent, InputItem, InputMessage, MessageContent, MessageRole,
};
fn create_test_state(response_id: &str) -> OpenAIConversationState {
OpenAIConversationState {
@ -320,7 +321,10 @@ mod tests {
let result = storage.get("nonexistent_id").await;
assert!(result.is_err());
assert!(matches!(result.unwrap_err(), StateStorageError::NotFound(_)));
assert!(matches!(
result.unwrap_err(),
StateStorageError::NotFound(_)
));
}
#[tokio::test]
@ -372,7 +376,10 @@ mod tests {
let result = storage.delete("nonexistent_id").await;
assert!(result.is_err());
assert!(matches!(result.unwrap_err(), StateStorageError::NotFound(_)));
assert!(matches!(
result.unwrap_err(),
StateStorageError::NotFound(_)
));
}
#[tokio::test]
@ -423,9 +430,13 @@ mod tests {
println!("✅ Data written to Supabase!");
println!("Check your Supabase dashboard:");
println!(" SELECT * FROM conversation_states WHERE response_id = 'manual_test_verification';");
println!(
" SELECT * FROM conversation_states WHERE response_id = 'manual_test_verification';"
);
println!("\nTo cleanup, run:");
println!(" DELETE FROM conversation_states WHERE response_id = 'manual_test_verification';");
println!(
" DELETE FROM conversation_states WHERE response_id = 'manual_test_verification';"
);
// DON'T cleanup - leave it for manual verification
}

View file

@ -1,13 +1,11 @@
use bytes::Bytes;
use flate2::read::GzDecoder;
use hermesllm::apis::openai_responses::{
InputItem, OutputItem, ResponsesAPIStreamEvent,
};
use hermesllm::apis::openai_responses::{InputItem, OutputItem, ResponsesAPIStreamEvent};
use hermesllm::apis::streaming_shapes::sse::SseStreamIter;
use hermesllm::transforms::response::output_to_input::outputs_to_inputs;
use std::io::Read;
use std::sync::Arc;
use tracing::{info, debug, warn};
use tracing::{debug, info, warn};
use crate::handlers::utils::StreamProcessor;
use crate::state::{OpenAIConversationState, StateStorage};
@ -53,6 +51,7 @@ pub struct ResponsesStateProcessor<P: StreamProcessor> {
}
impl<P: StreamProcessor> ResponsesStateProcessor<P> {
#[allow(clippy::too_many_arguments)]
pub fn new(
inner: P,
storage: Arc<dyn StateStorage>,
@ -139,20 +138,19 @@ impl<P: StreamProcessor> ResponsesStateProcessor<P> {
for event in sse_iter {
// Only process data lines (skip event-only lines)
if let Some(data_str) = &event.data {
// Try to parse as ResponsesAPIStreamEvent
if let Ok(stream_event) = serde_json::from_str::<ResponsesAPIStreamEvent>(data_str) {
// Check if this is a ResponseCompleted event
if let ResponsesAPIStreamEvent::ResponseCompleted { response, .. } = stream_event {
info!(
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Captured streaming response.completed: response_id={}, output_items={}",
self.request_id,
response.id,
response.output.len()
);
self.response_id = Some(response.id.clone());
self.output_items = Some(response.output.clone());
return; // Found what we need, exit early
}
// Try to parse as ResponsesAPIStreamEvent and check if it's a ResponseCompleted event
if let Ok(ResponsesAPIStreamEvent::ResponseCompleted { response, .. }) =
serde_json::from_str::<ResponsesAPIStreamEvent>(data_str)
{
info!(
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Captured streaming response.completed: response_id={}, output_items={}",
self.request_id,
response.id,
response.output.len()
);
self.response_id = Some(response.id.clone());
self.output_items = Some(response.output.clone());
return; // Found what we need, exit early
}
}
}
@ -172,7 +170,9 @@ impl<P: StreamProcessor> ResponsesStateProcessor<P> {
let decompressed = self.decompress_buffer();
// Parse complete JSON response
match serde_json::from_slice::<hermesllm::apis::openai_responses::ResponsesAPIResponse>(&decompressed) {
match serde_json::from_slice::<hermesllm::apis::openai_responses::ResponsesAPIResponse>(
&decompressed,
) {
Ok(response) => {
info!(
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Captured non-streaming response: response_id={}, output_items={}",

View file

@ -2,11 +2,9 @@
///
/// This module defines standard attribute keys following OTEL semantic conventions.
/// See: https://opentelemetry.io/docs/specs/semconv/
// =============================================================================
// Span Attributes - HTTP
// =============================================================================
/// Semantic conventions for HTTP-related span attributes
pub mod http {
/// HTTP request method

View file

@ -1,3 +1,3 @@
mod constants;
pub use constants::{OperationNameBuilder, operation_component, http, llm, error, routing};
pub use constants::{error, http, llm, operation_component, routing, OperationNameBuilder};

View file

@ -295,11 +295,14 @@ impl serde::Serialize for OrchestrationPreference {
let mut state = serializer.serialize_struct("OrchestrationPreference", 3)?;
state.serialize_field("name", &self.name)?;
state.serialize_field("description", &self.description)?;
state.serialize_field("parameters", &serde_json::json!({
"type": "object",
"properties": {},
"required": []
}))?;
state.serialize_field(
"parameters",
&serde_json::json!({
"type": "object",
"properties": {},
"required": []
}),
)?;
state.end()
}
}
@ -489,7 +492,10 @@ mod test {
assert_eq!(config.version, "v0.3.0");
if let Some(prompt_targets) = &config.prompt_targets {
assert!(!prompt_targets.is_empty(), "prompt_targets should not be empty if present");
assert!(
!prompt_targets.is_empty(),
"prompt_targets should not be empty if present"
);
}
if let Some(tracing) = config.tracing.as_ref() {
@ -510,19 +516,48 @@ mod test {
.expect("reference config file not found");
let config: super::Configuration = serde_yaml::from_str(&ref_config).unwrap();
if let Some(prompt_targets) = &config.prompt_targets {
if let Some(prompt_target) = prompt_targets.iter().find(|p| p.name == "reboot_network_device") {
if let Some(prompt_target) = prompt_targets
.iter()
.find(|p| p.name == "reboot_network_device")
{
let chat_completion_tool: super::ChatCompletionTool = prompt_target.into();
assert_eq!(chat_completion_tool.tool_type, ToolType::Function);
assert_eq!(chat_completion_tool.function.name, "reboot_network_device");
assert_eq!(chat_completion_tool.function.description, "Reboot a specific network device");
assert_eq!(
chat_completion_tool.function.description,
"Reboot a specific network device"
);
assert_eq!(chat_completion_tool.function.parameters.properties.len(), 2);
assert!(chat_completion_tool.function.parameters.properties.contains_key("device_id"));
let device_id_param = chat_completion_tool.function.parameters.properties.get("device_id").unwrap();
assert_eq!(device_id_param.parameter_type, crate::api::open_ai::ParameterType::String);
assert_eq!(device_id_param.description, "Identifier of the network device to reboot.".to_string());
assert!(chat_completion_tool
.function
.parameters
.properties
.contains_key("device_id"));
let device_id_param = chat_completion_tool
.function
.parameters
.properties
.get("device_id")
.unwrap();
assert_eq!(
device_id_param.parameter_type,
crate::api::open_ai::ParameterType::String
);
assert_eq!(
device_id_param.description,
"Identifier of the network device to reboot.".to_string()
);
assert_eq!(device_id_param.required, Some(true));
let confirmation_param = chat_completion_tool.function.parameters.properties.get("confirmation").unwrap();
assert_eq!(confirmation_param.parameter_type, crate::api::open_ai::ParameterType::Bool);
let confirmation_param = chat_completion_tool
.function
.parameters
.properties
.get("confirmation")
.unwrap();
assert_eq!(
confirmation_param.parameter_type,
crate::api::open_ai::ParameterType::Bool
);
}
}
}

View file

@ -32,6 +32,6 @@ pub const OTEL_COLLECTOR_HTTP: &str = "opentelemetry_collector_http";
pub const OTEL_POST_PATH: &str = "/v1/traces";
pub const LLM_ROUTE_HEADER: &str = "x-arch-llm-route";
pub const ENVOY_RETRY_HEADER: &str = "x-envoy-max-retries";
pub const BRIGHT_STAFF_SERVICE_NAME : &str = "brightstaff";
pub const BRIGHT_STAFF_SERVICE_NAME: &str = "brightstaff";
pub const PLANO_ORCHESTRATOR_MODEL_NAME: &str = "Plano-Orchestrator";
pub const ARCH_FC_CLUSTER: &str = "arch";

View file

@ -10,6 +10,6 @@ pub mod ratelimit;
pub mod routing;
pub mod stats;
pub mod tokenizer;
pub mod tracing;
pub mod traces;
pub mod tracing;
pub mod utils;

View file

@ -41,7 +41,8 @@ pub fn get_llm_provider(
llm_providers
.iter()
.filter(|(_, provider)| {
provider.model
provider
.model
.as_ref()
.map(|m| !m.starts_with("Arch"))
.unwrap_or(true)

View file

@ -1,5 +1,5 @@
use super::shapes::Span;
use super::resource_span_builder::ResourceSpanBuilder;
use super::shapes::Span;
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use tokio::sync::Mutex;
@ -160,7 +160,11 @@ impl TraceCollector {
}
let total_spans: usize = service_batches.iter().map(|(_, spans)| spans.len()).sum();
debug!("Flushing {} spans across {} services to OTEL collector", total_spans, service_batches.len());
debug!(
"Flushing {} spans across {} services to OTEL collector",
total_spans,
service_batches.len()
);
// Build canonical OTEL payload structure - one ResourceSpan per service
let resource_spans = self.build_resource_spans(service_batches);
@ -178,7 +182,10 @@ impl TraceCollector {
}
/// Build OTEL-compliant resource spans from collected spans, one ResourceSpan per service
fn build_resource_spans(&self, service_batches: Vec<(String, Vec<Span>)>) -> Vec<super::shapes::ResourceSpan> {
fn build_resource_spans(
&self,
service_batches: Vec<(String, Vec<Span>)>,
) -> Vec<super::shapes::ResourceSpan> {
service_batches
.into_iter()
.map(|(service_name, spans)| {

View file

@ -1,7 +1,6 @@
/// OpenTelemetry semantic convention constants for tracing
///
/// These constants ensure consistency across the codebase and prevent typos
/// Resource attribute keys following OTEL semantic conventions
pub mod resource {
/// Logical name of the service

View file

@ -1,9 +1,9 @@
// Original tracing types (OTEL structures)
mod shapes;
// New tracing utilities
mod span_builder;
mod resource_span_builder;
mod constants;
mod resource_span_builder;
mod span_builder;
#[cfg(feature = "trace-collection")]
mod collector;
@ -13,14 +13,14 @@ mod tests;
// Re-export original types
pub use shapes::{
Span, Event, Traceparent, TraceparentNewError,
ResourceSpan, Resource, ScopeSpan, Scope, Attribute, AttributeValue,
Attribute, AttributeValue, Event, Resource, ResourceSpan, Scope, ScopeSpan, Span, Traceparent,
TraceparentNewError,
};
// Re-export new utilities
pub use span_builder::{SpanBuilder, SpanKind, generate_random_span_id};
pub use resource_span_builder::ResourceSpanBuilder;
pub use constants::*;
pub use resource_span_builder::ResourceSpanBuilder;
pub use span_builder::{generate_random_span_id, SpanBuilder, SpanKind};
#[cfg(feature = "trace-collection")]
pub use collector::{TraceCollector, parse_traceparent};
pub use collector::{parse_traceparent, TraceCollector};

View file

@ -1,5 +1,5 @@
use super::shapes::{ResourceSpan, Resource, ScopeSpan, Scope, Span, Attribute, AttributeValue};
use super::constants::{resource, scope};
use super::shapes::{Attribute, AttributeValue, Resource, ResourceSpan, Scope, ScopeSpan, Span};
use std::collections::HashMap;
/// Builder for creating OTEL ResourceSpan structures
@ -26,7 +26,11 @@ impl ResourceSpanBuilder {
}
/// Add a resource attribute (e.g., deployment.environment, host.name)
pub fn with_resource_attribute(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
pub fn with_resource_attribute(
mut self,
key: impl Into<String>,
value: impl Into<String>,
) -> Self {
self.resource_attributes.insert(key.into(), value.into());
self
}
@ -58,14 +62,12 @@ impl ResourceSpanBuilder {
/// Build the ResourceSpan
pub fn build(self) -> ResourceSpan {
// Build resource attributes
let mut attributes = vec![
Attribute {
key: resource::SERVICE_NAME.to_string(),
value: AttributeValue {
string_value: Some(self.service_name),
},
}
];
let mut attributes = vec![Attribute {
key: resource::SERVICE_NAME.to_string(),
value: AttributeValue {
string_value: Some(self.service_name),
},
}];
// Add custom resource attributes
for (key, value) in self.resource_attributes {

View file

@ -1,4 +1,4 @@
use super::shapes::{Span, Attribute, AttributeValue};
use super::shapes::{Attribute, AttributeValue, Span};
use std::collections::HashMap;
use std::time::SystemTime;
@ -116,10 +116,11 @@ impl SpanBuilder {
let end_nanos = system_time_to_nanos(end_time);
// Generate trace_id if not provided
let trace_id = self.trace_id.unwrap_or_else(|| generate_random_trace_id());
let trace_id = self.trace_id.unwrap_or_else(generate_random_trace_id);
// Create attributes in OTEL format
let attributes: Vec<Attribute> = self.attributes
let attributes: Vec<Attribute> = self
.attributes
.into_iter()
.map(|(key, value)| Attribute {
key,
@ -132,7 +133,7 @@ impl SpanBuilder {
// Build span directly without going through Span::new()
Span {
trace_id,
span_id: self.span_id.unwrap_or_else(|| generate_random_span_id()),
span_id: self.span_id.unwrap_or_else(generate_random_span_id),
parent_span_id: self.parent_span_id,
name: self.name,
start_time_unix_nano: format!("{}", start_nanos),

View file

@ -21,10 +21,7 @@ use tokio::sync::RwLock;
type SharedTraces = Arc<RwLock<Vec<Value>>>;
/// POST /v1/traces - capture incoming OTLP payload
async fn post_traces(
State(traces): State<SharedTraces>,
Json(payload): Json<Value>,
) -> StatusCode {
async fn post_traces(State(traces): State<SharedTraces>, Json(payload): Json<Value>) -> StatusCode {
traces.write().await.push(payload);
StatusCode::OK
}
@ -67,9 +64,7 @@ impl MockOtelCollector {
let address = format!("http://127.0.0.1:{}", addr.port());
let server_handle = tokio::spawn(async move {
axum::serve(listener, app)
.await
.expect("Server failed");
axum::serve(listener, app).await.expect("Server failed");
});
// Give server a moment to start

View file

@ -36,9 +36,12 @@ fn extract_spans(payloads: &[Value]) -> Vec<&Value> {
for payload in payloads {
if let Some(resource_spans) = payload.get("resourceSpans").and_then(|v| v.as_array()) {
for resource_span in resource_spans {
if let Some(scope_spans) = resource_span.get("scopeSpans").and_then(|v| v.as_array()) {
if let Some(scope_spans) =
resource_span.get("scopeSpans").and_then(|v| v.as_array())
{
for scope_span in scope_spans {
if let Some(span_list) = scope_span.get("spans").and_then(|v| v.as_array()) {
if let Some(span_list) = scope_span.get("spans").and_then(|v| v.as_array())
{
spans.extend(span_list.iter());
}
}
@ -54,9 +57,9 @@ fn get_string_attr<'a>(span: &'a Value, key: &str) -> Option<&'a str> {
span.get("attributes")
.and_then(|attrs| attrs.as_array())
.and_then(|attrs| {
attrs.iter().find(|attr| {
attr.get("key").and_then(|k| k.as_str()) == Some(key)
})
attrs
.iter()
.find(|attr| attr.get("key").and_then(|k| k.as_str()) == Some(key))
})
.and_then(|attr| attr.get("value"))
.and_then(|v| v.get("stringValue"))
@ -70,7 +73,10 @@ async fn test_llm_span_contains_basic_attributes() {
let mock_collector = MockOtelCollector::start().await;
// Create TraceCollector pointing to mock with 500ms flush intervalc
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address()));
std::env::set_var(
"OTEL_COLLECTOR_URL",
format!("{}/v1/traces", mock_collector.address()),
);
std::env::set_var("OTEL_TRACING_ENABLED", "true");
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
let trace_collector = Arc::new(TraceCollector::new(Some(true)));
@ -102,7 +108,10 @@ async fn test_llm_span_contains_basic_attributes() {
let span = spans[0];
// Validate HTTP attributes
assert_eq!(get_string_attr(span, "http.method"), Some("POST"));
assert_eq!(get_string_attr(span, "http.target"), Some("/v1/chat/completions"));
assert_eq!(
get_string_attr(span, "http.target"),
Some("/v1/chat/completions")
);
// Validate LLM attributes
assert_eq!(get_string_attr(span, "llm.model"), Some("gpt-4o"));
@ -115,7 +124,10 @@ async fn test_llm_span_contains_basic_attributes() {
#[serial]
async fn test_llm_span_contains_tool_information() {
let mock_collector = MockOtelCollector::start().await;
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address()));
std::env::set_var(
"OTEL_COLLECTOR_URL",
format!("{}/v1/traces", mock_collector.address()),
);
std::env::set_var("OTEL_TRACING_ENABLED", "true");
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
let trace_collector = Arc::new(TraceCollector::new(Some(true)));
@ -144,19 +156,26 @@ async fn test_llm_span_contains_tool_information() {
assert!(tools.unwrap().contains("get_weather(...)"));
assert!(tools.unwrap().contains("search_web(...)"));
assert!(tools.unwrap().contains("calculate(...)"));
assert!(tools.unwrap().contains('\n'), "Tools should be newline-separated");
assert!(
tools.unwrap().contains('\n'),
"Tools should be newline-separated"
);
}
#[tokio::test]
#[serial]
async fn test_llm_span_contains_user_message_preview() {
let mock_collector = MockOtelCollector::start().await;
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address()));
std::env::set_var(
"OTEL_COLLECTOR_URL",
format!("{}/v1/traces", mock_collector.address()),
);
std::env::set_var("OTEL_TRACING_ENABLED", "true");
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
let trace_collector = Arc::new(TraceCollector::new(Some(true)));
let long_message = "This is a very long user message that should be truncated to 50 characters in the span";
let long_message =
"This is a very long user message that should be truncated to 50 characters in the span";
let preview = if long_message.len() > 50 {
format!("{}...", &long_message[..50])
} else {
@ -187,7 +206,10 @@ async fn test_llm_span_contains_user_message_preview() {
#[serial]
async fn test_llm_span_contains_time_to_first_token() {
let mock_collector = MockOtelCollector::start().await;
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address()));
std::env::set_var(
"OTEL_COLLECTOR_URL",
format!("{}/v1/traces", mock_collector.address()),
);
std::env::set_var("OTEL_TRACING_ENABLED", "true");
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
let trace_collector = Arc::new(TraceCollector::new(Some(true)));
@ -217,7 +239,10 @@ async fn test_llm_span_contains_time_to_first_token() {
#[serial]
async fn test_llm_span_contains_upstream_path() {
let mock_collector = MockOtelCollector::start().await;
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address()));
std::env::set_var(
"OTEL_COLLECTOR_URL",
format!("{}/v1/traces", mock_collector.address()),
);
std::env::set_var("OTEL_TRACING_ENABLED", "true");
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
let trace_collector = Arc::new(TraceCollector::new(Some(true)));
@ -241,7 +266,10 @@ async fn test_llm_span_contains_upstream_path() {
// Operation name should show the transformation
let name = span.get("name").and_then(|v| v.as_str());
assert!(name.is_some());
assert!(name.unwrap().contains(">>"), "Operation name should show path transformation");
assert!(
name.unwrap().contains(">>"),
"Operation name should show path transformation"
);
// Check upstream target attribute
let upstream = get_string_attr(span, "http.upstream_target");
@ -252,7 +280,10 @@ async fn test_llm_span_contains_upstream_path() {
#[serial]
async fn test_llm_span_multiple_services() {
let mock_collector = MockOtelCollector::start().await;
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address()));
std::env::set_var(
"OTEL_COLLECTOR_URL",
format!("{}/v1/traces", mock_collector.address()),
);
std::env::set_var("OTEL_TRACING_ENABLED", "true");
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
let trace_collector = Arc::new(TraceCollector::new(Some(true)));
@ -285,7 +316,10 @@ async fn test_tracing_disabled_produces_no_spans() {
let mock_collector = MockOtelCollector::start().await;
// Create TraceCollector with tracing DISABLED
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address()));
std::env::set_var(
"OTEL_COLLECTOR_URL",
format!("{}/v1/traces", mock_collector.address()),
);
std::env::set_var("OTEL_TRACING_ENABLED", "false");
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
let trace_collector = Arc::new(TraceCollector::new(Some(false)));
@ -300,5 +334,9 @@ async fn test_tracing_disabled_produces_no_spans() {
let payloads = mock_collector.get_traces().await;
let all_spans = extract_spans(&payloads);
assert_eq!(all_spans.len(), 0, "No spans should be captured when tracing is disabled");
assert_eq!(
all_spans.len(),
0,
"No spans should be captured when tracing is disabled"
);
}

View file

@ -161,13 +161,12 @@ impl TraceData {
}
pub fn new_with_service_name(service_name: String) -> Self {
let mut resource_attributes = Vec::new();
resource_attributes.push(Attribute {
let resource_attributes = vec![Attribute {
key: "service.name".to_string(),
value: AttributeValue {
string_value: Some(service_name),
},
});
}];
let resource = Resource {
attributes: resource_attributes,
@ -194,7 +193,9 @@ impl TraceData {
pub fn add_span(&mut self, span: Span) {
if self.resource_spans.is_empty() {
let resource = Resource { attributes: Vec::new() };
let resource = Resource {
attributes: Vec::new(),
};
let scope_span = ScopeSpan {
scope: Scope {
name: "default".to_string(),

View file

@ -66,7 +66,7 @@ impl ApiDefinition for AmazonBedrockApi {
/// Amazon Bedrock Converse request
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct ConverseRequest {
/// The model ID or ARN to invoke
pub model_id: String,
@ -91,7 +91,7 @@ pub struct ConverseRequest {
pub additional_model_response_field_paths: Option<Vec<String>>,
/// Performance configuration
#[serde(rename = "performanceConfig")]
pub performance_config: Option<PerformanceConfiguration>,
pub performance_config: Option<InferenceConfiguration>,
/// Prompt variables for Prompt management
#[serde(rename = "promptVariables")]
pub prompt_variables: Option<HashMap<String, PromptVariableValues>>,
@ -105,26 +105,6 @@ pub struct ConverseRequest {
pub stream: bool,
}
impl Default for ConverseRequest {
fn default() -> Self {
Self {
model_id: String::new(),
messages: None,
system: None,
inference_config: None,
tool_config: None,
guardrail_config: None,
additional_model_request_fields: None,
additional_model_response_field_paths: None,
performance_config: None,
prompt_variables: None,
request_metadata: None,
metadata: None,
stream: false,
}
}
}
/// Amazon Bedrock ConverseStream request (same structure as Converse)
pub type ConverseStreamRequest = ConverseRequest;
@ -204,8 +184,8 @@ impl ProviderRequest for ConverseRequest {
self.tool_config.as_ref()?.tools.as_ref().map(|tools| {
tools
.iter()
.filter_map(|tool| match tool {
Tool::ToolSpec { tool_spec } => Some(tool_spec.name.clone()),
.map(|tool| match tool {
Tool::ToolSpec { tool_spec } => tool_spec.name.clone(),
})
.collect()
})
@ -242,17 +222,14 @@ impl ProviderRequest for ConverseRequest {
// Add system messages if present
if let Some(system) = &self.system {
for sys_block in system {
match sys_block {
SystemContentBlock::Text { text } => {
openai_messages.push(Message {
role: Role::System,
content: MessageContent::Text(text.clone()),
name: None,
tool_calls: None,
tool_call_id: None,
});
}
_ => {} // Skip other system content types
if let SystemContentBlock::Text { text } = sys_block {
openai_messages.push(Message {
role: Role::System,
content: MessageContent::Text(text.clone()),
name: None,
tool_calls: None,
tool_call_id: None,
});
}
}
}
@ -266,7 +243,9 @@ impl ProviderRequest for ConverseRequest {
};
// Extract text from content blocks
let content = msg.content.iter()
let content = msg
.content
.iter()
.filter_map(|block| {
if let ContentBlock::Text { text } = block {
Some(text.clone())
@ -311,16 +290,14 @@ impl ProviderRequest for ConverseRequest {
_ => continue,
};
let content = if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
vec![ContentBlock::Text { text: text.clone() }]
} else {
vec![]
};
let content =
if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
vec![ContentBlock::Text { text: text.clone() }]
} else {
vec![]
};
bedrock_messages.push(crate::apis::amazon_bedrock::Message {
role,
content,
});
bedrock_messages.push(crate::apis::amazon_bedrock::Message { role, content });
}
_ => {}
}
@ -369,7 +346,7 @@ pub enum ConverseStreamEvent {
ContentBlockDelta(ContentBlockDeltaEvent),
ContentBlockStop(ContentBlockStopEvent),
MessageStop(MessageStopEvent),
Metadata(ConverseStreamMetadataEvent),
Metadata(Box<ConverseStreamMetadataEvent>),
// Error events
InternalServerException(BedrockException),
ModelStreamErrorException(BedrockException),
@ -1063,7 +1040,7 @@ impl TryFrom<&aws_smithy_eventstream::frame::DecodedFrame> for ConverseStreamEve
"metadata" => {
let event: ConverseStreamMetadataEvent =
serde_json::from_slice(payload).map_err(BedrockError::Serialization)?;
Ok(ConverseStreamEvent::Metadata(event))
Ok(ConverseStreamEvent::Metadata(Box::new(event)))
}
unknown => Err(BedrockError::Validation {
message: format!("Unknown event type: {}", unknown),
@ -1106,10 +1083,10 @@ impl TryFrom<&aws_smithy_eventstream::frame::DecodedFrame> for ConverseStreamEve
}
}
impl Into<String> for ConverseStreamEvent {
fn into(self) -> String {
let transformed_json = serde_json::to_string(&self).unwrap_or_default();
let event_type = match &self {
impl From<ConverseStreamEvent> for String {
fn from(val: ConverseStreamEvent) -> String {
let transformed_json = serde_json::to_string(&val).unwrap_or_default();
let event_type = match &val {
ConverseStreamEvent::MessageStart { .. } => "message_start",
ConverseStreamEvent::ContentBlockStart { .. } => "content_block_start",
ConverseStreamEvent::ContentBlockDelta { .. } => "content_block_delta",

File diff suppressed because one or more lines are too long

View file

@ -286,7 +286,6 @@ pub struct ImageUrl {
}
/// A single message in a chat conversation
/// A tool call made by the assistant
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ToolCall {
@ -388,7 +387,7 @@ pub enum StaticContentType {
/// Chat completions API response
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct ChatCompletionsResponse {
pub id: String,
pub object: Option<String>,
@ -402,22 +401,6 @@ pub struct ChatCompletionsResponse {
pub metadata: Option<HashMap<String, Value>>,
}
impl Default for ChatCompletionsResponse {
fn default() -> Self {
ChatCompletionsResponse {
id: String::new(),
object: None,
created: 0,
model: String::new(),
choices: vec![],
usage: Usage::default(),
system_fingerprint: None,
service_tier: None,
metadata: None,
}
}
}
/// Finish reason for completion
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
@ -431,7 +414,7 @@ pub enum FinishReason {
/// Token usage information
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
@ -440,18 +423,6 @@ pub struct Usage {
pub completion_tokens_details: Option<CompletionTokensDetails>,
}
impl Default for Usage {
fn default() -> Self {
Usage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
prompt_tokens_details: None,
completion_tokens_details: None,
}
}
}
/// Detailed breakdown of prompt tokens
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
@ -472,7 +443,7 @@ pub struct CompletionTokensDetails {
/// A single choice in the response
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct Choice {
pub index: u32,
pub message: ResponseMessage,
@ -480,17 +451,6 @@ pub struct Choice {
pub logprobs: Option<Value>,
}
impl Default for Choice {
fn default() -> Self {
Choice {
index: 0,
message: ResponseMessage::default(),
finish_reason: None,
logprobs: None,
}
}
}
// ============================================================================
// STREAMING API TYPES
// ============================================================================
@ -608,7 +568,6 @@ pub enum OpenAIError {
// ============================================================================
/// Trait Implementations
/// ===========================================================================
/// Parameterized conversion for ChatCompletionsRequest
impl TryFrom<&[u8]> for ChatCompletionsRequest {
type Error = OpenAIStreamError;
@ -721,7 +680,7 @@ impl ProviderRequest for ChatCompletionsRequest {
}
fn metadata(&self) -> &Option<HashMap<String, Value>> {
return &self.metadata;
&self.metadata
}
fn remove_metadata_key(&mut self, key: &str) -> bool {

View file

@ -1,7 +1,7 @@
use std::collections::HashMap;
use crate::providers::request::{ProviderRequest, ProviderRequestError};
use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none;
use crate::providers::request::{ProviderRequest, ProviderRequestError};
use std::collections::HashMap;
impl TryFrom<&[u8]> for ResponsesAPIRequest {
type Error = serde_json::Error;
@ -172,18 +172,14 @@ pub enum MessageRole {
#[serde(tag = "type", rename_all = "snake_case")]
pub enum InputContent {
/// Text input
InputText {
text: String,
},
InputText { text: String },
/// Image input via URL
InputImage {
image_url: String,
detail: Option<String>,
},
/// File input via URL
InputFile {
file_url: String,
},
InputFile { file_url: String },
/// Audio input
InputAudio {
data: Option<String>,
@ -222,9 +218,7 @@ pub struct TextConfig {
pub enum TextFormat {
Text,
JsonObject,
JsonSchema {
json_schema: serde_json::Value,
},
JsonSchema { json_schema: serde_json::Value },
}
/// Reasoning effort levels
@ -608,9 +602,7 @@ pub enum OutputContent {
transcript: Option<String>,
},
/// Refusal output
Refusal {
refusal: String,
},
Refusal { refusal: String },
}
/// Annotations for output text
@ -663,13 +655,9 @@ pub struct FileSearchResult {
#[serde(tag = "type", rename_all = "snake_case")]
pub enum CodeInterpreterOutput {
/// Text output
Text {
text: String,
},
Text { text: String },
/// Image output
Image {
image: String,
},
Image { image: String },
}
/// Response usage statistics
@ -951,9 +939,7 @@ pub enum ResponsesAPIStreamEvent {
},
/// Done event (end of stream)
Done {
sequence_number: i32,
},
Done { sequence_number: i32 },
}
// ============================================================================
@ -1052,12 +1038,19 @@ impl ProviderRequest for ResponsesAPIRequest {
MessageContent::Text(text) => text.clone(),
MessageContent::Items(content_items) => {
content_items.iter().fold(String::new(), |acc, content| {
acc + " " + &match content {
InputContent::InputText { text } => text.clone(),
InputContent::InputImage { .. } => "[Image]".to_string(),
InputContent::InputFile { .. } => "[File]".to_string(),
InputContent::InputAudio { .. } => "[Audio]".to_string(),
}
acc + " "
+ &match content {
InputContent::InputText { text } => text.clone(),
InputContent::InputImage { .. } => {
"[Image]".to_string()
}
InputContent::InputFile { .. } => {
"[File]".to_string()
}
InputContent::InputAudio { .. } => {
"[Audio]".to_string()
}
}
})
}
};
@ -1082,11 +1075,9 @@ impl ProviderRequest for ResponsesAPIRequest {
match &msg.content {
MessageContent::Text(text) => Some(text.clone()),
MessageContent::Items(content_items) => {
content_items.iter().find_map(|content| {
match content {
InputContent::InputText { text } => Some(text.clone()),
_ => None,
}
content_items.iter().find_map(|content| match content {
InputContent::InputText { text } => Some(text.clone()),
_ => None,
})
}
}
@ -1176,9 +1167,12 @@ impl ProviderRequest for ResponsesAPIRequest {
// Extract text from message content
let content = match &msg.content {
crate::apis::openai_responses::MessageContent::Text(text) => text.clone(),
crate::apis::openai_responses::MessageContent::Text(text) => {
text.clone()
}
crate::apis::openai_responses::MessageContent::Items(items) => {
items.iter()
items
.iter()
.filter_map(|c| {
if let InputContent::InputText { text } = c {
Some(text.clone())
@ -1214,7 +1208,8 @@ impl ProviderRequest for ResponsesAPIRequest {
fn set_messages(&mut self, messages: &[crate::apis::openai::Message]) {
// For ResponsesAPI, we need to convert messages back to input format
// Extract system messages as instructions
let system_text = messages.iter()
let system_text = messages
.iter()
.filter(|msg| msg.role == crate::apis::openai::Role::System)
.filter_map(|msg| {
if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
@ -1233,23 +1228,27 @@ impl ProviderRequest for ResponsesAPIRequest {
// Convert user/assistant messages to InputParam
// For simplicity, we'll use the last user message as the input
// or combine all non-system messages
let input_messages: Vec<_> = messages.iter()
let input_messages: Vec<_> = messages
.iter()
.filter(|msg| msg.role != crate::apis::openai::Role::System)
.collect();
if !input_messages.is_empty() {
// If there's only one message, use Text format
if input_messages.len() == 1 {
if let crate::apis::openai::MessageContent::Text(text) = &input_messages[0].content {
if let crate::apis::openai::MessageContent::Text(text) = &input_messages[0].content
{
self.input = crate::apis::openai_responses::InputParam::Text(text.clone());
}
} else {
// Multiple messages - combine them as text for now
// A more sophisticated approach would use InputParam::Items
let combined_text = input_messages.iter()
let combined_text = input_messages
.iter()
.filter_map(|msg| {
if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
Some(format!("{}: {}",
Some(format!(
"{}: {}",
match msg.role {
crate::apis::openai::Role::User => "User",
crate::apis::openai::Role::Assistant => "Assistant",
@ -1274,10 +1273,10 @@ impl ProviderRequest for ResponsesAPIRequest {
// Into<String> Implementation for SSE Formatting
// ============================================================================
impl Into<String> for ResponsesAPIStreamEvent {
fn into(self) -> String {
let transformed_json = serde_json::to_string(&self).unwrap_or_default();
let event_type = match &self {
impl From<ResponsesAPIStreamEvent> for String {
fn from(val: ResponsesAPIStreamEvent) -> Self {
let transformed_json = serde_json::to_string(&val).unwrap_or_default();
let event_type = match &val {
ResponsesAPIStreamEvent::ResponseCreated { .. } => "response.created",
ResponsesAPIStreamEvent::ResponseInProgress { .. } => "response.in_progress",
ResponsesAPIStreamEvent::ResponseCompleted { .. } => "response.completed",
@ -1365,10 +1364,10 @@ impl crate::providers::streaming_response::ProviderStreamResponse for ResponsesA
fn role(&self) -> Option<&str> {
match self {
ResponsesAPIStreamEvent::ResponseOutputItemDone { item, .. } => match item {
OutputItem::Message { role, .. } => Some(role.as_str()),
_ => None,
},
ResponsesAPIStreamEvent::ResponseOutputItemDone {
item: OutputItem::Message { role, .. },
..
} => Some(role.as_str()),
_ => None,
}
}

View file

@ -34,10 +34,7 @@ where
}
pub fn decode_frame(&mut self) -> Option<DecodedFrame> {
match self.decoder.decode_frame(&mut self.buffer) {
Ok(frame) => Some(frame),
Err(_e) => None, // Fatal decode error
}
self.decoder.decode_frame(&mut self.buffer).ok()
}
pub fn buffer_mut(&mut self) -> &mut B {

View file

@ -1,5 +1,5 @@
use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait};
use crate::apis::anthropic::MessagesStreamEvent;
use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait};
use crate::providers::streaming_response::ProviderStreamResponseType;
use std::collections::HashSet;
@ -31,6 +31,12 @@ pub struct AnthropicMessagesStreamBuffer {
model: Option<String>,
}
impl Default for AnthropicMessagesStreamBuffer {
fn default() -> Self {
Self::new()
}
}
impl AnthropicMessagesStreamBuffer {
pub fn new() -> Self {
Self {
@ -154,7 +160,8 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
// Inject message_start if needed
if !self.message_started {
let model = self.model.as_deref().unwrap_or("unknown");
let message_start = AnthropicMessagesStreamBuffer::create_message_start_event(model);
let message_start =
AnthropicMessagesStreamBuffer::create_message_start_event(model);
self.buffered_events.push(message_start);
self.message_started = true;
}
@ -169,7 +176,8 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
// Inject message_start if needed
if !self.message_started {
let model = self.model.as_deref().unwrap_or("unknown");
let message_start = AnthropicMessagesStreamBuffer::create_message_start_event(model);
let message_start =
AnthropicMessagesStreamBuffer::create_message_start_event(model);
self.buffered_events.push(message_start);
self.message_started = true;
}
@ -177,7 +185,8 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
// Check if ContentBlockStart was sent for this index
if !self.has_content_block_start_been_sent(index) {
// Inject ContentBlockStart before delta
let content_block_start = AnthropicMessagesStreamBuffer::create_content_block_start_event();
let content_block_start =
AnthropicMessagesStreamBuffer::create_content_block_start_event();
self.buffered_events.push(content_block_start);
self.set_content_block_start_sent(index);
self.needs_content_block_stop = true;
@ -189,7 +198,8 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
MessagesStreamEvent::MessageDelta { usage, .. } => {
// Inject ContentBlockStop before message_delta
if self.needs_content_block_stop {
let content_block_stop = AnthropicMessagesStreamBuffer::create_content_block_stop_event();
let content_block_stop =
AnthropicMessagesStreamBuffer::create_content_block_stop_event();
self.buffered_events.push(content_block_stop);
self.needs_content_block_stop = false;
}
@ -199,10 +209,10 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
if let Some(last_event) = self.buffered_events.last_mut() {
if let Some(ProviderStreamResponseType::MessagesStreamEvent(
MessagesStreamEvent::MessageDelta {
usage: last_usage,
..
}
)) = &mut last_event.provider_stream_response {
usage: last_usage, ..
},
)) = &mut last_event.provider_stream_response
{
// Merge: take stop_reason from first, usage from second (if non-zero)
if usage.input_tokens > 0 || usage.output_tokens > 0 {
*last_usage = usage.clone();
@ -243,7 +253,7 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
}
}
fn into_bytes(&mut self) -> Vec<u8> {
fn to_bytes(&mut self) -> Vec<u8> {
// Convert all accumulated events to bytes and clear buffer
// NOTE: We do NOT inject ContentBlockStop here because it's injected when we see MessageDelta
// or MessageStop. Injecting it here causes premature ContentBlockStop in the middle of streaming.
@ -276,10 +286,10 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
#[cfg(test)]
mod tests {
use super::*;
use crate::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
use crate::apis::anthropic::AnthropicApi;
use crate::apis::openai::OpenAIApi;
use crate::apis::streaming_shapes::sse::SseStreamIter;
use crate::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
#[test]
fn test_openai_to_anthropic_complete_transformation() {
@ -308,11 +318,12 @@ data: [DONE]"#;
let mut buffer = AnthropicMessagesStreamBuffer::new();
for raw_event in stream_iter {
let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
let transformed_event =
SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
buffer.add_transformed_event(transformed_event);
}
let output_bytes = buffer.into_bytes();
let output_bytes = buffer.to_bytes();
let output = String::from_utf8_lossy(&output_bytes);
println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):");
@ -321,25 +332,54 @@ data: [DONE]"#;
// Assertions
assert!(!output_bytes.is_empty(), "Should have output");
assert!(output.contains("event: message_start"), "Should have message_start");
assert!(output.contains("event: content_block_start"), "Should have content_block_start (injected)");
assert!(
output.contains("event: message_start"),
"Should have message_start"
);
assert!(
output.contains("event: content_block_start"),
"Should have content_block_start (injected)"
);
let delta_count = output.matches("event: content_block_delta").count();
assert_eq!(delta_count, 2, "Should have exactly 2 content_block_delta events");
assert_eq!(
delta_count, 2,
"Should have exactly 2 content_block_delta events"
);
// Verify both pieces of content are present
assert!(output.contains("\"text\":\"Hello\""), "Should have first content delta 'Hello'");
assert!(output.contains("\"text\":\" world\""), "Should have second content delta ' world'");
assert!(
output.contains("\"text\":\"Hello\""),
"Should have first content delta 'Hello'"
);
assert!(
output.contains("\"text\":\" world\""),
"Should have second content delta ' world'"
);
assert!(output.contains("event: content_block_stop"), "Should have content_block_stop (injected)");
assert!(output.contains("event: message_delta"), "Should have message_delta");
assert!(output.contains("event: message_stop"), "Should have message_stop");
assert!(
output.contains("event: content_block_stop"),
"Should have content_block_stop (injected)"
);
assert!(
output.contains("event: message_delta"),
"Should have message_delta"
);
assert!(
output.contains("event: message_stop"),
"Should have message_stop"
);
println!("\nVALIDATION SUMMARY:");
println!("{}", "-".repeat(80));
println!("✓ Complete transformation: OpenAI ChatCompletions → Anthropic Messages API");
println!("✓ Injected lifecycle events: message_start, content_block_start, content_block_stop");
println!("✓ Content deltas: {} events (BOTH 'Hello' and ' world' preserved!)", delta_count);
println!(
"✓ Injected lifecycle events: message_start, content_block_start, content_block_stop"
);
println!(
"✓ Content deltas: {} events (BOTH 'Hello' and ' world' preserved!)",
delta_count
);
println!("✓ Complete stream with message_stop");
println!("✓ Proper Anthropic protocol sequencing\n");
}
@ -369,11 +409,12 @@ data: {"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890
let mut buffer = AnthropicMessagesStreamBuffer::new();
for raw_event in stream_iter {
let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
let transformed_event =
SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
buffer.add_transformed_event(transformed_event);
}
let output_bytes = buffer.into_bytes();
let output_bytes = buffer.to_bytes();
let output = String::from_utf8_lossy(&output_bytes);
println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):");
@ -382,31 +423,61 @@ data: {"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890
// Assertions
assert!(!output_bytes.is_empty(), "Should have output");
assert!(output.contains("event: message_start"), "Should have message_start");
assert!(output.contains("event: content_block_start"), "Should have content_block_start (injected)");
assert!(
output.contains("event: message_start"),
"Should have message_start"
);
assert!(
output.contains("event: content_block_start"),
"Should have content_block_start (injected)"
);
let delta_count = output.matches("event: content_block_delta").count();
assert_eq!(delta_count, 3, "Should have exactly 3 content_block_delta events");
assert_eq!(
delta_count, 3,
"Should have exactly 3 content_block_delta events"
);
// Verify all three pieces of content are present
assert!(output.contains("\"text\":\"The weather\""), "Should have first content delta");
assert!(output.contains("\"text\":\" in San Francisco\""), "Should have second content delta");
assert!(output.contains("\"text\":\" is\""), "Should have third content delta");
assert!(
output.contains("\"text\":\"The weather\""),
"Should have first content delta"
);
assert!(
output.contains("\"text\":\" in San Francisco\""),
"Should have second content delta"
);
assert!(
output.contains("\"text\":\" is\""),
"Should have third content delta"
);
// For partial streams (no finish_reason, no [DONE]), we do NOT inject content_block_stop
// because the stream may continue. This is correct behavior - only inject lifecycle events
// when we have explicit signals from upstream (finish_reason, [DONE], etc.)
assert!(!output.contains("event: content_block_stop"), "Should NOT have content_block_stop for partial stream");
assert!(
!output.contains("event: content_block_stop"),
"Should NOT have content_block_stop for partial stream"
);
// Should NOT have completion events
assert!(!output.contains("event: message_delta"), "Should NOT have message_delta");
assert!(!output.contains("event: message_stop"), "Should NOT have message_stop");
assert!(
!output.contains("event: message_delta"),
"Should NOT have message_delta"
);
assert!(
!output.contains("event: message_stop"),
"Should NOT have message_stop"
);
println!("\nVALIDATION SUMMARY:");
println!("{}", "-".repeat(80));
println!("✓ Partial transformation: OpenAI → Anthropic (stream interrupted)");
println!("✓ Injected: message_start, content_block_start at beginning");
println!("✓ Incremental deltas: {} events (ALL content preserved!)", delta_count);
println!(
"✓ Incremental deltas: {} events (ALL content preserved!)",
delta_count
);
println!("✓ NO completion events (partial stream, no [DONE])");
println!("✓ Buffer maintains Anthropic protocol for active streams\n");
}
@ -452,11 +523,12 @@ data: [DONE]"#;
let mut buffer = AnthropicMessagesStreamBuffer::new();
for raw_event in stream_iter {
let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
let transformed_event =
SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
buffer.add_transformed_event(transformed_event);
}
let output_bytes = buffer.into_bytes();
let output_bytes = buffer.to_bytes();
let output = String::from_utf8_lossy(&output_bytes);
println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):");
@ -467,32 +539,71 @@ data: [DONE]"#;
assert!(!output_bytes.is_empty(), "Should have output");
// Should have lifecycle events (injected by buffer)
assert!(output.contains("event: message_start"), "Should have message_start (injected)");
assert!(output.contains("event: content_block_start"), "Should have content_block_start");
assert!(output.contains("event: content_block_stop"), "Should have content_block_stop (injected)");
assert!(output.contains("event: message_delta"), "Should have message_delta");
assert!(output.contains("event: message_stop"), "Should have message_stop");
assert!(
output.contains("event: message_start"),
"Should have message_start (injected)"
);
assert!(
output.contains("event: content_block_start"),
"Should have content_block_start"
);
assert!(
output.contains("event: content_block_stop"),
"Should have content_block_stop (injected)"
);
assert!(
output.contains("event: message_delta"),
"Should have message_delta"
);
assert!(
output.contains("event: message_stop"),
"Should have message_stop"
);
// Should have tool_use content block
assert!(output.contains("\"type\":\"tool_use\""), "Should have tool_use type");
assert!(output.contains("\"name\":\"get_weather\""), "Should have correct function name");
assert!(output.contains("\"id\":\"call_2Uzw0AEZQeOex2CP2TKjcLKc\""), "Should have correct tool call ID");
assert!(
output.contains("\"type\":\"tool_use\""),
"Should have tool_use type"
);
assert!(
output.contains("\"name\":\"get_weather\""),
"Should have correct function name"
);
assert!(
output.contains("\"id\":\"call_2Uzw0AEZQeOex2CP2TKjcLKc\""),
"Should have correct tool call ID"
);
// Count input_json_delta events - should match the number of argument chunks
let delta_count = output.matches("event: content_block_delta").count();
assert!(delta_count >= 8, "Should have at least 8 input_json_delta events");
assert!(
delta_count >= 8,
"Should have at least 8 input_json_delta events"
);
// Verify argument deltas are present
assert!(output.contains("\"type\":\"input_json_delta\""), "Should have input_json_delta type");
assert!(output.contains("\"partial_json\":"), "Should have partial_json field");
assert!(
output.contains("\"type\":\"input_json_delta\""),
"Should have input_json_delta type"
);
assert!(
output.contains("\"partial_json\":"),
"Should have partial_json field"
);
// Verify the accumulated arguments contain the location
assert!(output.contains("San"), "Arguments should contain 'San'");
assert!(output.contains("Francisco"), "Arguments should contain 'Francisco'");
assert!(
output.contains("Francisco"),
"Arguments should contain 'Francisco'"
);
assert!(output.contains("CA"), "Arguments should contain 'CA'");
// Verify stop reason is tool_use
assert!(output.contains("\"stop_reason\":\"tool_use\""), "Should have stop_reason as tool_use");
assert!(
output.contains("\"stop_reason\":\"tool_use\""),
"Should have stop_reason as tool_use"
);
println!("\nVALIDATION SUMMARY:");
println!("{}", "-".repeat(80));

View file

@ -6,6 +6,12 @@ pub struct OpenAIChatCompletionsStreamBuffer {
buffered_events: Vec<SseEvent>,
}
impl Default for OpenAIChatCompletionsStreamBuffer {
fn default() -> Self {
Self::new()
}
}
impl OpenAIChatCompletionsStreamBuffer {
pub fn new() -> Self {
Self {
@ -26,7 +32,7 @@ impl SseStreamBufferTrait for OpenAIChatCompletionsStreamBuffer {
self.buffered_events.push(event);
}
fn into_bytes(&mut self) -> Vec<u8> {
fn to_bytes(&mut self) -> Vec<u8> {
// No finalization needed for OpenAI Chat Completions
// The [DONE] marker is already handled by the transformation layer
let mut buffer = Vec::new();

View file

@ -1,7 +1,7 @@
pub mod sse;
pub mod sse_chunk_processor;
pub mod amazon_bedrock_binary_frame;
pub mod anthropic_streaming_buffer;
pub mod chat_completions_streaming_buffer;
pub mod passthrough_streaming_buffer;
pub mod responses_api_streaming_buffer;
pub mod sse;
pub mod sse_chunk_processor;

View file

@ -6,6 +6,12 @@ pub struct PassthroughStreamBuffer {
buffered_events: Vec<SseEvent>,
}
impl Default for PassthroughStreamBuffer {
fn default() -> Self {
Self::new()
}
}
impl PassthroughStreamBuffer {
pub fn new() -> Self {
Self {
@ -30,7 +36,7 @@ impl SseStreamBufferTrait for PassthroughStreamBuffer {
self.buffered_events.push(event);
}
fn into_bytes(&mut self) -> Vec<u8> {
fn to_bytes(&mut self) -> Vec<u8> {
// No finalization needed for passthrough - just convert accumulated events to bytes
let mut buffer = Vec::new();
for event in self.buffered_events.drain(..) {
@ -44,7 +50,7 @@ impl SseStreamBufferTrait for PassthroughStreamBuffer {
#[cfg(test)]
mod tests {
use crate::apis::streaming_shapes::passthrough_streaming_buffer::PassthroughStreamBuffer;
use crate::apis::streaming_shapes::sse::{SseStreamIter, SseStreamBufferTrait};
use crate::apis::streaming_shapes::sse::{SseStreamBufferTrait, SseStreamIter};
#[test]
fn test_chat_completions_passthrough_buffer() {
@ -73,7 +79,7 @@ mod tests {
buffer.add_transformed_event(event);
}
let output_bytes = buffer.into_bytes();
let output_bytes = buffer.to_bytes();
let output = String::from_utf8_lossy(&output_bytes);
println!("\nTRANSFORMED OUTPUT (ChatCompletions - Passthrough):");
@ -84,7 +90,11 @@ mod tests {
assert!(!output_bytes.is_empty());
assert!(output.contains("chatcmpl-123"));
assert!(output.contains("[DONE]"));
assert_eq!(raw_input.trim(), output.trim(), "Passthrough should preserve input");
assert_eq!(
raw_input.trim(),
output.trim(),
"Passthrough should preserve input"
);
println!("\nVALIDATION SUMMARY:");
println!("{}", "-".repeat(80));

View file

@ -1,10 +1,10 @@
use std::collections::HashMap;
use log::debug;
use crate::apis::openai_responses::{
ResponsesAPIStreamEvent, ResponsesAPIResponse, OutputItem, OutputItemStatus,
ResponseStatus, TextConfig, TextFormat, Reasoning,
OutputItem, OutputItemStatus, Reasoning, ResponseStatus, ResponsesAPIResponse,
ResponsesAPIStreamEvent, TextConfig, TextFormat,
};
use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait};
use log::debug;
use std::collections::HashMap;
/// Helper to convert ResponseAPIStreamEvent to SseEvent
fn event_to_sse(event: ResponsesAPIStreamEvent) -> SseEvent {
@ -16,10 +16,17 @@ fn event_to_sse(event: ResponsesAPIStreamEvent) -> SseEvent {
ResponsesAPIStreamEvent::ResponseOutputItemDone { .. } => "response.output_item.done",
ResponsesAPIStreamEvent::ResponseOutputTextDelta { .. } => "response.output_text.delta",
ResponsesAPIStreamEvent::ResponseOutputTextDone { .. } => "response.output_text.done",
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { .. } => "response.function_call_arguments.delta",
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDone { .. } => "response.function_call_arguments.done",
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { .. } => {
"response.function_call_arguments.delta"
}
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDone { .. } => {
"response.function_call_arguments.done"
}
unknown => {
debug!("Unknown ResponsesAPIStreamEvent type encountered: {:?}", unknown);
debug!(
"Unknown ResponsesAPIStreamEvent type encountered: {:?}",
unknown
);
"unknown"
}
};
@ -85,6 +92,12 @@ pub struct ResponsesAPIStreamBuffer {
buffered_events: Vec<SseEvent>,
}
impl Default for ResponsesAPIStreamBuffer {
fn default() -> Self {
Self::new()
}
}
impl ResponsesAPIStreamBuffer {
pub fn new() -> Self {
Self {
@ -112,7 +125,11 @@ impl ResponsesAPIStreamBuffer {
}
fn generate_item_id(prefix: &str) -> String {
format!("{}_{}", prefix, uuid::Uuid::new_v4().to_string().replace("-", ""))
format!(
"{}_{}",
prefix,
uuid::Uuid::new_v4().to_string().replace("-", "")
)
}
fn get_or_create_item_id(&mut self, output_index: i32, prefix: &str) -> String {
@ -160,7 +177,13 @@ impl ResponsesAPIStreamBuffer {
}
/// Create output_item.added event for tool call
fn create_tool_call_added_event(&mut self, output_index: i32, item_id: &str, call_id: &str, name: &str) -> SseEvent {
fn create_tool_call_added_event(
&mut self,
output_index: i32,
item_id: &str,
call_id: &str,
name: &str,
) -> SseEvent {
let event = ResponsesAPIStreamEvent::ResponseOutputItemAdded {
output_index,
item: OutputItem::FunctionCall {
@ -237,9 +260,15 @@ impl ResponsesAPIStreamBuffer {
// Emit done events for all accumulated content
// Text content done events
let text_items: Vec<_> = self.text_content.iter().map(|(id, content)| (id.clone(), content.clone())).collect();
let text_items: Vec<_> = self
.text_content
.iter()
.map(|(id, content)| (id.clone(), content.clone()))
.collect();
for (item_id, content) in text_items {
let output_index = self.output_items_added.iter()
let output_index = self
.output_items_added
.iter()
.find(|(_, id)| **id == item_id)
.map(|(idx, _)| *idx)
.unwrap_or(0);
@ -270,9 +299,15 @@ impl ResponsesAPIStreamBuffer {
}
// Function call done events
let func_items: Vec<_> = self.function_arguments.iter().map(|(id, args)| (id.clone(), args.clone())).collect();
let func_items: Vec<_> = self
.function_arguments
.iter()
.map(|(id, args)| (id.clone(), args.clone()))
.collect();
for (item_id, arguments) in func_items {
let output_index = self.output_items_added.iter()
let output_index = self
.output_items_added
.iter()
.find(|(_, id)| **id == item_id)
.map(|(idx, _)| *idx)
.unwrap_or(0);
@ -286,9 +321,16 @@ impl ResponsesAPIStreamBuffer {
};
events.push(event_to_sse(args_done_event));
let (call_id, name) = self.tool_call_metadata.get(&output_index)
let (call_id, name) = self
.tool_call_metadata
.get(&output_index)
.cloned()
.unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string()));
.unwrap_or_else(|| {
(
format!("call_{}", uuid::Uuid::new_v4()),
"unknown".to_string(),
)
});
let seq2 = self.next_sequence_number();
let item_done_event = ResponsesAPIStreamEvent::ResponseOutputItemDone {
@ -315,9 +357,16 @@ impl ResponsesAPIStreamBuffer {
if let Some(item_id) = self.output_items_added.get(&output_index) {
// Check if this is a function call
if let Some(arguments) = self.function_arguments.get(item_id) {
let (call_id, name) = self.tool_call_metadata.get(&output_index)
let (call_id, name) = self
.tool_call_metadata
.get(&output_index)
.cloned()
.unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string()));
.unwrap_or_else(|| {
(
format!("call_{}", uuid::Uuid::new_v4()),
"unknown".to_string(),
)
});
output_items.push(OutputItem::FunctionCall {
id: item_id.clone(),
@ -397,9 +446,9 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
let mut events = Vec::new();
// Capture upstream metadata from ResponseCreated or ResponseInProgress if present
match stream_event {
ResponsesAPIStreamEvent::ResponseCreated { response, .. } |
ResponsesAPIStreamEvent::ResponseInProgress { response, .. } => {
match stream_event.as_ref() {
ResponsesAPIStreamEvent::ResponseCreated { response, .. }
| ResponsesAPIStreamEvent::ResponseInProgress { response, .. } => {
if self.upstream_response_metadata.is_none() {
// Store the full upstream response as our metadata template
self.upstream_response_metadata = Some(response.clone());
@ -418,11 +467,16 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
if !self.created_emitted {
// Initialize metadata from first event if needed
if self.response_id.is_none() {
self.response_id = Some(format!("resp_{}", uuid::Uuid::new_v4().to_string().replace("-", "")));
self.created_at = Some(std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs() as i64);
self.response_id = Some(format!(
"resp_{}",
uuid::Uuid::new_v4().to_string().replace("-", "")
));
self.created_at = Some(
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs() as i64,
);
self.model = Some("unknown".to_string()); // Will be set by caller if available
}
@ -436,58 +490,95 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
}
// Process the delta event
match stream_event {
ResponsesAPIStreamEvent::ResponseOutputTextDelta { output_index, delta, .. } => {
match stream_event.as_ref() {
ResponsesAPIStreamEvent::ResponseOutputTextDelta {
output_index,
delta,
..
} => {
let item_id = self.get_or_create_item_id(*output_index, "msg");
// Emit output_item.added if this is the first time we see this output index
if !self.output_items_added.contains_key(output_index) {
self.output_items_added.insert(*output_index, item_id.clone());
self.output_items_added
.insert(*output_index, item_id.clone());
events.push(self.create_output_item_added_event(*output_index, &item_id));
}
// Accumulate text content
self.text_content.entry(item_id.clone())
self.text_content
.entry(item_id.clone())
.and_modify(|content| content.push_str(delta))
.or_insert_with(|| delta.clone());
// Emit text delta with filled-in item_id and sequence_number
let mut delta_event = stream_event.clone();
if let ResponsesAPIStreamEvent::ResponseOutputTextDelta { item_id: ref mut id, sequence_number: ref mut seq, .. } = delta_event {
let mut delta_event = stream_event.as_ref().clone();
if let ResponsesAPIStreamEvent::ResponseOutputTextDelta {
item_id: ref mut id,
sequence_number: ref mut seq,
..
} = &mut delta_event
{
*id = item_id;
*seq = self.next_sequence_number();
}
events.push(event_to_sse(delta_event));
}
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { output_index, delta, call_id, name, .. } => {
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta {
output_index,
delta,
call_id,
name,
..
} => {
let item_id = self.get_or_create_item_id(*output_index, "fc");
// Store metadata if provided (from initial tool call event)
if let (Some(cid), Some(n)) = (call_id, name) {
self.tool_call_metadata.insert(*output_index, (cid.clone(), n.clone()));
self.tool_call_metadata
.insert(*output_index, (cid.clone(), n.clone()));
}
// Emit output_item.added if this is the first time we see this tool call
if !self.output_items_added.contains_key(output_index) {
self.output_items_added.insert(*output_index, item_id.clone());
self.output_items_added
.insert(*output_index, item_id.clone());
// For tool calls, we need call_id and name from metadata
// These should now be populated from the event itself
let (call_id, name) = self.tool_call_metadata.get(output_index)
let (call_id, name) = self
.tool_call_metadata
.get(output_index)
.cloned()
.unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string()));
.unwrap_or_else(|| {
(
format!("call_{}", uuid::Uuid::new_v4()),
"unknown".to_string(),
)
});
events.push(self.create_tool_call_added_event(*output_index, &item_id, &call_id, &name));
events.push(self.create_tool_call_added_event(
*output_index,
&item_id,
&call_id,
&name,
));
}
// Accumulate function arguments
self.function_arguments.entry(item_id.clone())
self.function_arguments
.entry(item_id.clone())
.and_modify(|args| args.push_str(delta))
.or_insert_with(|| delta.clone());
// Emit function call arguments delta with filled-in item_id and sequence_number
let mut delta_event = stream_event.clone();
if let ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { item_id: ref mut id, sequence_number: ref mut seq, .. } = delta_event {
let mut delta_event = stream_event.as_ref().clone();
if let ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta {
item_id: ref mut id,
sequence_number: ref mut seq,
..
} = &mut delta_event
{
*id = item_id;
*seq = self.next_sequence_number();
}
@ -495,7 +586,7 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
}
_ => {
// For other event types, just pass through with sequence number
let other_event = stream_event.clone();
let other_event = stream_event.as_ref().clone();
// TODO: Add sequence number to other event types if needed
events.push(event_to_sse(other_event));
}
@ -505,8 +596,7 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
self.buffered_events.extend(events);
}
fn into_bytes(&mut self) -> Vec<u8> {
fn to_bytes(&mut self) -> Vec<u8> {
// For Responses API, we need special handling:
// - Most events are already in buffered_events from add_transformed_event
// - We should NOT finalize here - finalization happens when we detect [DONE] or end of stream
@ -525,9 +615,9 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
#[cfg(test)]
mod tests {
use super::*;
use crate::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
use crate::apis::openai::OpenAIApi;
use crate::apis::streaming_shapes::sse::SseStreamIter;
use crate::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
#[test]
fn test_chat_completions_to_responses_api_transformation() {
@ -557,11 +647,12 @@ mod tests {
for raw_event in stream_iter {
// Transform the event using the client/upstream APIs
let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
let transformed_event =
SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
buffer.add_transformed_event(transformed_event);
}
let output_bytes = buffer.into_bytes();
let output_bytes = buffer.to_bytes();
let output = String::from_utf8_lossy(&output_bytes);
println!("\nTRANSFORMED OUTPUT (ResponsesAPI):");
@ -570,13 +661,34 @@ mod tests {
// Assertions
assert!(!output_bytes.is_empty(), "Should have output");
assert!(output.contains("response.created"), "Should have response.created");
assert!(output.contains("response.in_progress"), "Should have response.in_progress");
assert!(output.contains("response.output_item.added"), "Should have output_item.added");
assert!(output.contains("response.output_text.delta"), "Should have text deltas");
assert!(output.contains("response.output_text.done"), "Should have text.done");
assert!(output.contains("response.output_item.done"), "Should have output_item.done");
assert!(output.contains("response.completed"), "Should have response.completed");
assert!(
output.contains("response.created"),
"Should have response.created"
);
assert!(
output.contains("response.in_progress"),
"Should have response.in_progress"
);
assert!(
output.contains("response.output_item.added"),
"Should have output_item.added"
);
assert!(
output.contains("response.output_text.delta"),
"Should have text deltas"
);
assert!(
output.contains("response.output_text.done"),
"Should have text.done"
);
assert!(
output.contains("response.output_item.done"),
"Should have output_item.done"
);
assert!(
output.contains("response.completed"),
"Should have response.completed"
);
println!("\nVALIDATION SUMMARY:");
println!("{}", "-".repeat(80));
@ -616,7 +728,7 @@ mod tests {
buffer.add_transformed_event(transformed);
}
let output_bytes = buffer.into_bytes();
let output_bytes = buffer.to_bytes();
let output = String::from_utf8_lossy(&output_bytes);
println!("\nTRANSFORMED OUTPUT (ResponsesAPI):");
@ -624,24 +736,55 @@ mod tests {
println!("{}", output);
// Assertions
assert!(output.contains("response.created"), "Should have response.created");
assert!(output.contains("response.in_progress"), "Should have response.in_progress");
assert!(output.contains("response.output_item.added"), "Should have output_item.added");
assert!(output.contains("\"type\":\"function_call\""), "Should be function_call type");
assert!(output.contains("\"name\":\"get_weather\""), "Should have function name");
assert!(output.contains("\"call_id\":\"call_mD5ggLKk3SMKGPFqFdcpKg6q\""), "Should have correct call_id");
assert!(
output.contains("response.created"),
"Should have response.created"
);
assert!(
output.contains("response.in_progress"),
"Should have response.in_progress"
);
assert!(
output.contains("response.output_item.added"),
"Should have output_item.added"
);
assert!(
output.contains("\"type\":\"function_call\""),
"Should be function_call type"
);
assert!(
output.contains("\"name\":\"get_weather\""),
"Should have function name"
);
assert!(
output.contains("\"call_id\":\"call_mD5ggLKk3SMKGPFqFdcpKg6q\""),
"Should have correct call_id"
);
let delta_count = output.matches("event: response.function_call_arguments.delta").count();
let delta_count = output
.matches("event: response.function_call_arguments.delta")
.count();
assert_eq!(delta_count, 4, "Should have 4 delta events");
assert!(!output.contains("response.function_call_arguments.done"), "Should NOT have arguments.done");
assert!(!output.contains("response.output_item.done"), "Should NOT have output_item.done");
assert!(!output.contains("response.completed"), "Should NOT have response.completed");
assert!(
!output.contains("response.function_call_arguments.done"),
"Should NOT have arguments.done"
);
assert!(
!output.contains("response.output_item.done"),
"Should NOT have output_item.done"
);
assert!(
!output.contains("response.completed"),
"Should NOT have response.completed"
);
println!("\nVALIDATION SUMMARY:");
println!("{}", "-".repeat(80));
println!("✓ Lifecycle events: response.created, response.in_progress");
println!("✓ Function call metadata: name='get_weather', call_id='call_mD5ggLKk3SMKGPFqFdcpKg6q'");
println!(
"✓ Function call metadata: name='get_weather', call_id='call_mD5ggLKk3SMKGPFqFdcpKg6q'"
);
println!("✓ Incremental deltas: 4 events (1 initial + 3 argument chunks)");
println!("✓ NO completion events (partial stream, no [DONE])");
println!("✓ Arguments accumulated: '{{\"location\":\"'\n");

View file

@ -1,9 +1,9 @@
use crate::providers::streaming_response::ProviderStreamResponse;
use crate::providers::streaming_response::ProviderStreamResponseType;
use crate::apis::streaming_shapes::chat_completions_streaming_buffer::OpenAIChatCompletionsStreamBuffer;
use crate::apis::streaming_shapes::anthropic_streaming_buffer::AnthropicMessagesStreamBuffer;
use crate::apis::streaming_shapes::chat_completions_streaming_buffer::OpenAIChatCompletionsStreamBuffer;
use crate::apis::streaming_shapes::passthrough_streaming_buffer::PassthroughStreamBuffer;
use crate::apis::streaming_shapes::responses_api_streaming_buffer::ResponsesAPIStreamBuffer;
use crate::providers::streaming_response::ProviderStreamResponse;
use crate::providers::streaming_response::ProviderStreamResponseType;
use serde::{Deserialize, Serialize};
use std::error::Error;
use std::fmt;
@ -37,7 +37,7 @@ pub trait SseStreamBufferTrait: Send + Sync {
///
/// # Returns
/// Bytes ready for wire transmission (may be empty if no events were accumulated)
fn into_bytes(&mut self) -> Vec<u8>;
fn to_bytes(&mut self) -> Vec<u8>;
}
/// Unified SSE Stream Buffer enum that provides a zero-cost abstraction
@ -45,7 +45,7 @@ pub enum SseStreamBuffer {
Passthrough(PassthroughStreamBuffer),
OpenAIChatCompletions(OpenAIChatCompletionsStreamBuffer),
AnthropicMessages(AnthropicMessagesStreamBuffer),
OpenAIResponses(ResponsesAPIStreamBuffer),
OpenAIResponses(Box<ResponsesAPIStreamBuffer>),
}
impl SseStreamBufferTrait for SseStreamBuffer {
@ -58,12 +58,12 @@ impl SseStreamBufferTrait for SseStreamBuffer {
}
}
fn into_bytes(&mut self) -> Vec<u8> {
fn to_bytes(&mut self) -> Vec<u8> {
match self {
Self::Passthrough(buffer) => buffer.into_bytes(),
Self::OpenAIChatCompletions(buffer) => buffer.into_bytes(),
Self::AnthropicMessages(buffer) => buffer.into_bytes(),
Self::OpenAIResponses(buffer) => buffer.into_bytes(),
Self::Passthrough(buffer) => buffer.to_bytes(),
Self::OpenAIChatCompletions(buffer) => buffer.to_bytes(),
Self::AnthropicMessages(buffer) => buffer.to_bytes(),
Self::OpenAIResponses(buffer) => buffer.to_bytes(),
}
}
}
@ -99,7 +99,7 @@ impl SseEvent {
let sse_string: String = response.clone().into();
SseEvent {
data: None, // Data is embedded in sse_transformed_lines
data: None, // Data is embedded in sse_transformed_lines
event: None, // Event type is embedded in sse_transformed_lines
raw_line: sse_string.clone(),
sse_transformed_lines: sse_string,
@ -149,10 +149,8 @@ impl FromStr for SseEvent {
});
}
if trimmed_line.starts_with("data: ") {
let data: String = trimmed_line[6..].to_string(); // Remove "data: " prefix
// Allow empty data content after "data: " prefix
// This handles cases like "data: " followed by newline
if let Some(stripped) = trimmed_line.strip_prefix("data: ") {
let data: String = stripped.to_string();
if data.trim().is_empty() {
return Err(SseParseError {
message: "Empty data field after 'data: ' prefix".to_string(),
@ -166,8 +164,8 @@ impl FromStr for SseEvent {
sse_transformed_lines: line.to_string(),
provider_stream_response: None,
})
} else if trimmed_line.starts_with("event: ") {
let event_type = trimmed_line[7..].to_string();
} else if let Some(stripped) = trimmed_line.strip_prefix("event: ") {
let event_type = stripped.to_string();
if event_type.is_empty() {
return Err(SseParseError {
message: "Empty event field is not a valid SSE event".to_string(),
@ -183,7 +181,10 @@ impl FromStr for SseEvent {
})
} else {
Err(SseParseError {
message: format!("Line does not start with 'data: ' or 'event: ': {}", trimmed_line),
message: format!(
"Line does not start with 'data: ' or 'event: ': {}",
trimmed_line
),
})
}
}
@ -196,16 +197,16 @@ impl fmt::Display for SseEvent {
}
// Into implementation to convert SseEvent to bytes for response buffer
impl Into<Vec<u8>> for SseEvent {
fn into(self) -> Vec<u8> {
impl From<SseEvent> for Vec<u8> {
fn from(val: SseEvent) -> Self {
// For generated events (like ResponsesAPI), sse_transformed_lines already includes trailing \n\n
// For parsed events (like passthrough), we need to add the \n\n separator
if self.sse_transformed_lines.ends_with("\n\n") {
if val.sse_transformed_lines.ends_with("\n\n") {
// Already properly formatted with trailing newlines
self.sse_transformed_lines.into_bytes()
val.sse_transformed_lines.into_bytes()
} else {
// Add SSE event separator
format!("{}\n\n", self.sse_transformed_lines).into_bytes()
format!("{}\n\n", val.sse_transformed_lines).into_bytes()
}
}
}

View file

@ -10,6 +10,12 @@ pub struct SseChunkProcessor {
incomplete_event_buffer: Vec<u8>,
}
impl Default for SseChunkProcessor {
fn default() -> Self {
Self::new()
}
}
impl SseChunkProcessor {
pub fn new() -> Self {
Self {
@ -93,8 +99,8 @@ impl SseChunkProcessor {
#[cfg(test)]
mod tests {
use super::*;
use crate::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
use crate::apis::openai::OpenAIApi;
use crate::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
#[test]
fn test_complete_events_process_immediately() {
@ -104,7 +110,9 @@ mod tests {
let chunk1 = b"data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n";
let events = processor.process_chunk(chunk1, &client_api, &upstream_api).unwrap();
let events = processor
.process_chunk(chunk1, &client_api, &upstream_api)
.unwrap();
assert_eq!(events.len(), 1);
assert!(!processor.has_buffered_data());
@ -119,18 +127,28 @@ mod tests {
// First chunk with incomplete JSON
let chunk1 = b"data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chu";
let events1 = processor.process_chunk(chunk1, &client_api, &upstream_api).unwrap();
let events1 = processor
.process_chunk(chunk1, &client_api, &upstream_api)
.unwrap();
assert_eq!(events1.len(), 0, "Incomplete event should not be processed");
assert!(processor.has_buffered_data(), "Incomplete data should be buffered");
assert!(
processor.has_buffered_data(),
"Incomplete data should be buffered"
);
// Second chunk completes the JSON
let chunk2 = b"nk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n";
let events2 = processor.process_chunk(chunk2, &client_api, &upstream_api).unwrap();
let events2 = processor
.process_chunk(chunk2, &client_api, &upstream_api)
.unwrap();
assert_eq!(events2.len(), 1, "Complete event should be processed");
assert!(!processor.has_buffered_data(), "Buffer should be cleared after completion");
assert!(
!processor.has_buffered_data(),
"Buffer should be cleared after completion"
);
}
#[test]
@ -142,10 +160,15 @@ mod tests {
// Chunk with 2 complete events and 1 incomplete
let chunk = b"data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"A\"},\"finish_reason\":null}]}\n\ndata: {\"id\":\"chatcmpl-124\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"B\"},\"finish_reason\":null}]}\n\ndata: {\"id\":\"chatcmpl-125\",\"object\":\"chat.completion.chu";
let events = processor.process_chunk(chunk, &client_api, &upstream_api).unwrap();
let events = processor
.process_chunk(chunk, &client_api, &upstream_api)
.unwrap();
assert_eq!(events.len(), 2, "Two complete events should be processed");
assert!(processor.has_buffered_data(), "Incomplete third event should be buffered");
assert!(
processor.has_buffered_data(),
"Incomplete third event should be buffered"
);
}
#[test]
@ -171,11 +194,23 @@ data: {"type":"content_block_stop","index":0}
Ok(events) => {
println!("Successfully processed {} events", events.len());
for (i, event) in events.iter().enumerate() {
println!("Event {}: event={:?}, has_data={}", i, event.event, event.data.is_some());
println!(
"Event {}: event={:?}, has_data={}",
i,
event.event,
event.data.is_some()
);
}
// Should successfully process both events (signature_delta + content_block_stop)
assert!(events.len() >= 2, "Should process at least 2 complete events (signature_delta + stop), got {}", events.len());
assert!(!processor.has_buffered_data(), "Complete events should not be buffered");
assert!(
events.len() >= 2,
"Should process at least 2 complete events (signature_delta + stop), got {}",
events.len()
);
assert!(
!processor.has_buffered_data(),
"Complete events should not be buffered"
);
}
Err(e) => {
panic!("Failed to process signature_delta chunk - this means SignatureDelta is not properly handled: {}", e);
@ -194,12 +229,21 @@ data: {"type":"content_block_stop","index":0}
// Second event is valid and should be processed
let chunk = b"data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"unsupported_field_causing_validation_error\":true},\"finish_reason\":null}]}\n\ndata: {\"id\":\"chatcmpl-124\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n";
let events = processor.process_chunk(chunk, &client_api, &upstream_api).unwrap();
let events = processor
.process_chunk(chunk, &client_api, &upstream_api)
.unwrap();
// Should skip the invalid event and process the valid one
// (If we were buffering all errors, we'd get 0 events and have buffered data)
assert!(events.len() >= 1, "Should process at least the valid event, got {} events", events.len());
assert!(!processor.has_buffered_data(), "Invalid (non-incomplete) events should not be buffered");
assert!(
!events.is_empty(),
"Should process at least the valid event, got {} events",
events.len()
);
assert!(
!processor.has_buffered_data(),
"Invalid (non-incomplete) events should not be buffered"
);
}
#[test]
@ -227,14 +271,27 @@ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text
match result {
Ok(events) => {
println!("Processed {} events (unsupported event should be skipped)", events.len());
println!(
"Processed {} events (unsupported event should be skipped)",
events.len()
);
// Should process the 2 valid text_delta events and skip the unsupported one
// We expect at least 2 events (the valid ones), unsupported should be skipped
assert!(events.len() >= 2, "Should process at least 2 valid events, got {}", events.len());
assert!(!processor.has_buffered_data(), "Unsupported events should be skipped, not buffered");
assert!(
events.len() >= 2,
"Should process at least 2 valid events, got {}",
events.len()
);
assert!(
!processor.has_buffered_data(),
"Unsupported events should be skipped, not buffered"
);
}
Err(e) => {
panic!("Should not fail on unsupported delta type, should skip it: {}", e);
panic!(
"Should not fail on unsupported delta type, should skip it: {}",
e
);
}
}
}

View file

@ -135,7 +135,10 @@ impl SupportedAPIsFromClient {
ProviderId::AzureOpenAI => {
if request_path.starts_with("/v1/") {
let suffix = endpoint_suffix.trim_start_matches('/');
build_endpoint("/openai/deployments", &format!("/{}/{}?api-version=2025-01-01-preview", model_id, suffix))
build_endpoint(
"/openai/deployments",
&format!("/{}/{}?api-version=2025-01-01-preview", model_id, suffix),
)
} else {
build_endpoint("/v1", endpoint_suffix)
}
@ -163,19 +166,21 @@ impl SupportedAPIsFromClient {
};
match self {
SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages) => match provider_id {
ProviderId::Anthropic => build_endpoint("/v1", "/messages"),
ProviderId::AmazonBedrock => {
if request_path.starts_with("/v1/") && !is_streaming {
build_endpoint("", &format!("/model/{}/converse", model_id))
} else if request_path.starts_with("/v1/") && is_streaming {
build_endpoint("", &format!("/model/{}/converse-stream", model_id))
} else {
build_endpoint("/v1", "/chat/completions")
SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages) => {
match provider_id {
ProviderId::Anthropic => build_endpoint("/v1", "/messages"),
ProviderId::AmazonBedrock => {
if request_path.starts_with("/v1/") && !is_streaming {
build_endpoint("", &format!("/model/{}/converse", model_id))
} else if request_path.starts_with("/v1/") && is_streaming {
build_endpoint("", &format!("/model/{}/converse-stream", model_id))
} else {
build_endpoint("/v1", "/chat/completions")
}
}
_ => build_endpoint("/v1", "/chat/completions"),
}
_ => build_endpoint("/v1", "/chat/completions"),
},
}
SupportedAPIsFromClient::OpenAIResponsesAPI(_) => {
// For Responses API, check if provider supports it, otherwise translate to chat/completions
match provider_id {
@ -193,7 +198,6 @@ impl SupportedAPIsFromClient {
}
}
impl SupportedUpstreamAPIs {
/// Create a SupportedUpstreamApi from an endpoint path
pub fn from_endpoint(endpoint: &str) -> Option<Self> {
@ -216,17 +220,17 @@ impl SupportedUpstreamAPIs {
return Some(SupportedUpstreamAPIs::AmazonBedrockConverse(bedrock_api))
}
AmazonBedrockApi::ConverseStream => {
return Some(SupportedUpstreamAPIs::AmazonBedrockConverseStream(bedrock_api))
return Some(SupportedUpstreamAPIs::AmazonBedrockConverseStream(
bedrock_api,
))
}
}
}
None
}
}
/// Get all supported endpoint paths
pub fn supported_endpoints() -> Vec<&'static str> {
let mut endpoints = Vec::new();
@ -269,9 +273,9 @@ mod tests {
assert!(SupportedAPIsFromClient::from_endpoint("/v1/messages").is_some());
// Unsupported endpoints
assert!(!SupportedAPIsFromClient::from_endpoint("/v1/unknown").is_some());
assert!(!SupportedAPIsFromClient::from_endpoint("/v2/chat").is_some());
assert!(!SupportedAPIsFromClient::from_endpoint("").is_some());
assert!(SupportedAPIsFromClient::from_endpoint("/v1/unknown").is_none());
assert!(SupportedAPIsFromClient::from_endpoint("/v2/chat").is_none());
assert!(SupportedAPIsFromClient::from_endpoint("").is_none());
}
#[test]

View file

@ -12,11 +12,9 @@ pub use aws_smithy_eventstream::frame::DecodedFrame;
pub use providers::id::ProviderId;
pub use providers::request::{ProviderRequest, ProviderRequestError, ProviderRequestType};
pub use providers::response::{
ProviderResponse, ProviderResponseType, TokenUsage, ProviderResponseError
};
pub use providers::streaming_response::{
ProviderStreamResponse, ProviderStreamResponseType
ProviderResponse, ProviderResponseError, ProviderResponseType, TokenUsage,
};
pub use providers::streaming_response::{ProviderStreamResponse, ProviderStreamResponseType};
//TODO: Refactor such that commons doesn't depend on Hermes. For now this will clean up strings
pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
@ -87,11 +85,17 @@ mod tests {
let done_event = streaming_iter.next();
assert!(done_event.is_some(), "Should get [DONE] event");
let done_event = done_event.unwrap();
assert!(done_event.is_done(), "[DONE] event should be marked as done");
assert!(
done_event.is_done(),
"[DONE] event should be marked as done"
);
// After [DONE], iterator should return None
let final_event = streaming_iter.next();
assert!(final_event.is_none(), "Iterator should return None after [DONE]");
assert!(
final_event.is_none(),
"Iterator should return None after [DONE]"
);
}
/// Test AWS Event Stream decoding for Bedrock ConverseStream responses.
@ -130,7 +134,7 @@ mod tests {
let mut content_chunks = Vec::new();
// Simulate chunked network arrivals - process as data comes in
let chunk_sizes = vec![50, 100, 75, 200, 150, 300, 500, 1000];
let chunk_sizes = [50, 100, 75, 200, 150, 300, 500, 1000];
let mut offset = 0;
let mut chunk_num = 0;

View file

@ -59,10 +59,9 @@ impl ProviderId {
(ProviderId::Anthropic, SupportedAPIsFromClient::AnthropicMessagesAPI(_)) => {
SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages)
}
(
ProviderId::Anthropic,
SupportedAPIsFromClient::OpenAIChatCompletions(_),
) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
(ProviderId::Anthropic, SupportedAPIsFromClient::OpenAIChatCompletions(_)) => {
SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions)
}
// Anthropic doesn't support Responses API, fall back to chat completions
(ProviderId::Anthropic, SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => {

View file

@ -10,6 +10,7 @@ use serde_json::Value;
use std::collections::HashMap;
use std::error::Error;
use std::fmt;
#[allow(clippy::large_enum_variant)]
#[derive(Clone, Debug)]
pub enum ProviderRequestType {
ChatCompletionsRequest(ChatCompletionsRequest),
@ -197,7 +198,9 @@ impl ProviderRequest for ProviderRequestType {
impl TryFrom<(&[u8], &SupportedAPIsFromClient)> for ProviderRequestType {
type Error = std::io::Error;
fn try_from((bytes, client_api): (&[u8], &SupportedAPIsFromClient)) -> Result<Self, Self::Error> {
fn try_from(
(bytes, client_api): (&[u8], &SupportedAPIsFromClient),
) -> Result<Self, Self::Error> {
// Use SupportedApi to determine the appropriate request type
match client_api {
SupportedAPIsFromClient::OpenAIChatCompletions(_) => {
@ -882,7 +885,7 @@ mod tests {
ProviderRequestType::BedrockConverse(bedrock_req) => {
assert_eq!(bedrock_req.model_id, "gpt-4o");
// Bedrock receives the converted request through ChatCompletions
assert!(!bedrock_req.messages.is_none());
assert!(bedrock_req.messages.is_some());
}
_ => panic!("Expected BedrockConverse variant"),
}
@ -913,7 +916,9 @@ mod tests {
assert!(result.is_err());
let err = result.unwrap_err();
assert!(err.message.contains("ResponsesAPI can only be used as a client API"));
assert!(err
.message
.contains("ResponsesAPI can only be used as a client API"));
}
#[test]
@ -953,7 +958,9 @@ mod tests {
assert!(result.is_err());
let err = result.unwrap_err();
assert!(err.message.contains("ResponsesAPI can only be used as a client API"));
assert!(err
.message
.contains("ResponsesAPI can only be used as a client API"));
}
#[test]
@ -1023,9 +1030,7 @@ mod tests {
role: MessagesRole::User,
content: MessagesMessageContent::Single("Hello!".to_string()),
}],
system: Some(MessagesSystemPrompt::Single(
"You are helpful".to_string(),
)),
system: Some(MessagesSystemPrompt::Single("You are helpful".to_string())),
max_tokens: 100,
container: None,
mcp_servers: None,
@ -1046,14 +1051,8 @@ mod tests {
// Should have system message + user message
assert_eq!(messages.len(), 2);
assert_eq!(
messages[0].role,
crate::apis::openai::Role::System
);
assert_eq!(
messages[1].role,
crate::apis::openai::Role::User
);
assert_eq!(messages[0].role, crate::apis::openai::Role::System);
assert_eq!(messages[1].role, crate::apis::openai::Role::User);
}
#[test]
@ -1094,13 +1093,7 @@ mod tests {
// Should have system message (instructions) + user message (input)
assert_eq!(messages.len(), 2);
assert_eq!(
messages[0].role,
crate::apis::openai::Role::System
);
assert_eq!(
messages[1].role,
crate::apis::openai::Role::User
);
assert_eq!(messages[0].role, crate::apis::openai::Role::System);
assert_eq!(messages[1].role, crate::apis::openai::Role::User);
}
}

View file

@ -1,7 +1,3 @@
use serde::Serialize;
use std::convert::TryFrom;
use std::error::Error;
use std::fmt;
use crate::apis::amazon_bedrock::ConverseResponse;
use crate::apis::anthropic::MessagesResponse;
use crate::apis::openai::ChatCompletionsResponse;
@ -9,14 +5,17 @@ use crate::apis::openai_responses::ResponsesAPIResponse;
use crate::clients::endpoints::SupportedAPIsFromClient;
use crate::clients::endpoints::SupportedUpstreamAPIs;
use crate::providers::id::ProviderId;
use serde::Serialize;
use std::convert::TryFrom;
use std::error::Error;
use std::fmt;
#[derive(Serialize, Debug, Clone)]
#[serde(untagged)]
pub enum ProviderResponseType {
ChatCompletionsResponse(ChatCompletionsResponse),
MessagesResponse(MessagesResponse),
ResponsesAPIResponse(ResponsesAPIResponse),
ResponsesAPIResponse(Box<ResponsesAPIResponse>),
}
/// Trait for token usage information
@ -42,7 +41,9 @@ impl ProviderResponse for ProviderResponseType {
match self {
ProviderResponseType::ChatCompletionsResponse(resp) => resp.usage(),
ProviderResponseType::MessagesResponse(resp) => resp.usage(),
ProviderResponseType::ResponsesAPIResponse(resp) => resp.usage.as_ref().map(|u| u as &dyn TokenUsage),
ProviderResponseType::ResponsesAPIResponse(resp) => {
resp.usage.as_ref().map(|u| u as &dyn TokenUsage)
}
}
}
@ -50,11 +51,13 @@ impl ProviderResponse for ProviderResponseType {
match self {
ProviderResponseType::ChatCompletionsResponse(resp) => resp.extract_usage_counts(),
ProviderResponseType::MessagesResponse(resp) => resp.extract_usage_counts(),
ProviderResponseType::ResponsesAPIResponse(resp) => {
resp.usage.as_ref().map(|u| {
(u.input_tokens as usize, u.output_tokens as usize, u.total_tokens as usize)
})
}
ProviderResponseType::ResponsesAPIResponse(resp) => resp.usage.as_ref().map(|u| {
(
u.input_tokens as usize,
u.output_tokens as usize,
u.total_tokens as usize,
)
}),
}
}
}
@ -156,40 +159,44 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClient, &ProviderId)> for ProviderRespons
) => {
let resp: ResponsesAPIResponse = ResponsesAPIResponse::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(ProviderResponseType::ResponsesAPIResponse(resp))
Ok(ProviderResponseType::ResponsesAPIResponse(Box::new(resp)))
}
(
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
SupportedAPIsFromClient::OpenAIResponsesAPI(_),
) => {
let chat_completions_response: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
let chat_completions_response: ChatCompletionsResponse =
ChatCompletionsResponse::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
// Transform to ResponsesAPI format using the transformer
let responses_resp: ResponsesAPIResponse = chat_completions_response.try_into().map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Transformation error: {}", e),
)
})?;
Ok(ProviderResponseType::ResponsesAPIResponse(responses_resp))
let responses_resp: ResponsesAPIResponse =
chat_completions_response.try_into().map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Transformation error: {}", e),
)
})?;
Ok(ProviderResponseType::ResponsesAPIResponse(Box::new(
responses_resp,
)))
}
(
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
SupportedAPIsFromClient::OpenAIResponsesAPI(_),
) => {
//Chain transform: Anthropic Messages -> OpenAI ChatCompletions -> ResponsesAPI
let anthropic_resp: MessagesResponse = serde_json::from_slice(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
// Transform to ChatCompletions format using the transformer
let chat_resp: ChatCompletionsResponse = anthropic_resp.try_into().map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Transformation error: {}", e),
)
})?;
let chat_resp: ChatCompletionsResponse =
anthropic_resp.try_into().map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Transformation error: {}", e),
)
})?;
let response_api: ResponsesAPIResponse = chat_resp.try_into().map_err(|e| {
std::io::Error::new(
@ -197,7 +204,9 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClient, &ProviderId)> for ProviderRespons
format!("Transformation error: {}", e),
)
})?;
Ok(ProviderResponseType::ResponsesAPIResponse(response_api))
Ok(ProviderResponseType::ResponsesAPIResponse(Box::new(
response_api,
)))
}
(
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
@ -219,10 +228,15 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClient, &ProviderId)> for ProviderRespons
let response_api: ResponsesAPIResponse = chat_resp.try_into().map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("ChatCompletions to ResponsesAPI transformation error: {}", e),
format!(
"ChatCompletions to ResponsesAPI transformation error: {}",
e
),
)
})?;
Ok(ProviderResponseType::ResponsesAPIResponse(response_api))
Ok(ProviderResponseType::ResponsesAPIResponse(Box::new(
response_api,
)))
}
_ => Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
@ -255,8 +269,8 @@ impl Error for ProviderResponseError {
#[cfg(test)]
mod tests {
use super::*;
use crate::apis::openai::OpenAIApi;
use crate::apis::anthropic::AnthropicApi;
use crate::apis::openai::OpenAIApi;
use crate::clients::endpoints::SupportedAPIsFromClient;
use crate::providers::id::ProviderId;
use serde_json::json;

View file

@ -1,18 +1,17 @@
use serde::Serialize;
use std::convert::TryFrom;
use crate::apis::amazon_bedrock::ConverseStreamEvent;
use crate::apis::anthropic::MessagesStreamEvent;
use crate::apis::openai::ChatCompletionsStreamResponse;
use crate::apis::openai_responses::ResponsesAPIStreamEvent;
use crate::apis::streaming_shapes::sse::SseEvent;
use crate::apis::amazon_bedrock::ConverseStreamEvent;
use crate::apis::anthropic::MessagesStreamEvent;
use crate::apis::streaming_shapes::sse::SseStreamBuffer;
use crate::apis::streaming_shapes::{
anthropic_streaming_buffer::AnthropicMessagesStreamBuffer,
chat_completions_streaming_buffer::OpenAIChatCompletionsStreamBuffer,
passthrough_streaming_buffer::PassthroughStreamBuffer,
responses_api_streaming_buffer::ResponsesAPIStreamBuffer,
};
anthropic_streaming_buffer::AnthropicMessagesStreamBuffer,
chat_completions_streaming_buffer::OpenAIChatCompletionsStreamBuffer,
passthrough_streaming_buffer::PassthroughStreamBuffer,
};
use crate::clients::endpoints::SupportedAPIsFromClient;
use crate::clients::endpoints::SupportedUpstreamAPIs;
@ -28,9 +27,18 @@ pub fn needs_buffering(
) -> bool {
match (client_api, upstream_api) {
// Same APIs - no buffering needed
(SupportedAPIsFromClient::OpenAIChatCompletions(_), SupportedUpstreamAPIs::OpenAIChatCompletions(_)) => false,
(SupportedAPIsFromClient::AnthropicMessagesAPI(_), SupportedUpstreamAPIs::AnthropicMessagesAPI(_)) => false,
(SupportedAPIsFromClient::OpenAIResponsesAPI(_), SupportedUpstreamAPIs::OpenAIResponsesAPI(_)) => false,
(
SupportedAPIsFromClient::OpenAIChatCompletions(_),
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
) => false,
(
SupportedAPIsFromClient::AnthropicMessagesAPI(_),
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
) => false,
(
SupportedAPIsFromClient::OpenAIResponsesAPI(_),
SupportedUpstreamAPIs::OpenAIResponsesAPI(_),
) => false,
// Different APIs - buffering needed
_ => true,
@ -53,15 +61,12 @@ pub fn needs_buffering(
/// // Flush to wire
/// let bytes = buffer.into_bytes();
/// ```
impl TryFrom<(&SupportedAPIsFromClient, &SupportedUpstreamAPIs)>
for SseStreamBuffer
{
impl TryFrom<(&SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for SseStreamBuffer {
type Error = Box<dyn std::error::Error + Send + Sync>;
fn try_from(
(client_api, upstream_api): (&SupportedAPIsFromClient, &SupportedUpstreamAPIs),
) -> Result<Self, Self::Error> {
// If APIs match, use passthrough - no buffering/transformation needed
if !needs_buffering(client_api, upstream_api) {
return Ok(SseStreamBuffer::Passthrough(PassthroughStreamBuffer::new()));
@ -69,14 +74,14 @@ impl TryFrom<(&SupportedAPIsFromClient, &SupportedUpstreamAPIs)>
// APIs differ - use appropriate buffer for client API
match client_api {
SupportedAPIsFromClient::OpenAIChatCompletions(_) => {
Ok(SseStreamBuffer::OpenAIChatCompletions(OpenAIChatCompletionsStreamBuffer::new()))
}
SupportedAPIsFromClient::AnthropicMessagesAPI(_) => {
Ok(SseStreamBuffer::AnthropicMessages(AnthropicMessagesStreamBuffer::new()))
}
SupportedAPIsFromClient::OpenAIChatCompletions(_) => Ok(
SseStreamBuffer::OpenAIChatCompletions(OpenAIChatCompletionsStreamBuffer::new()),
),
SupportedAPIsFromClient::AnthropicMessagesAPI(_) => Ok(
SseStreamBuffer::AnthropicMessages(AnthropicMessagesStreamBuffer::new()),
),
SupportedAPIsFromClient::OpenAIResponsesAPI(_) => {
Ok(SseStreamBuffer::OpenAIResponses(ResponsesAPIStreamBuffer::new()))
Ok(SseStreamBuffer::OpenAIResponses(Box::default()))
}
}
}
@ -88,11 +93,12 @@ impl TryFrom<(&SupportedAPIsFromClient, &SupportedUpstreamAPIs)>
#[derive(Serialize, Debug, Clone)]
#[serde(untagged)]
#[allow(clippy::large_enum_variant)]
pub enum ProviderStreamResponseType {
ChatCompletionsStreamResponse(ChatCompletionsStreamResponse),
MessagesStreamEvent(MessagesStreamEvent),
ConverseStreamEvent(ConverseStreamEvent),
ResponseAPIStreamEvent(ResponsesAPIStreamEvent)
ResponseAPIStreamEvent(Box<ResponsesAPIStreamEvent>),
}
pub trait ProviderStreamResponse: Send + Sync {
@ -145,12 +151,11 @@ impl ProviderStreamResponse for ProviderStreamResponseType {
ProviderStreamResponseType::ResponseAPIStreamEvent(resp) => resp.event_type(),
}
}
}
impl Into<String> for ProviderStreamResponseType {
fn into(self) -> String {
match self {
impl From<ProviderStreamResponseType> for String {
fn from(val: ProviderStreamResponseType) -> String {
match val {
ProviderStreamResponseType::MessagesStreamEvent(event) => {
// Use the Into<String> implementation for proper SSE formatting with event lines
event.into()
@ -161,27 +166,36 @@ impl Into<String> for ProviderStreamResponseType {
}
ProviderStreamResponseType::ResponseAPIStreamEvent(event) => {
// Use the Into<String> implementation for proper SSE formatting with event lines
event.into()
// Clone to work around Box<T> ownership
let cloned = (*event).clone();
cloned.into()
}
ProviderStreamResponseType::ChatCompletionsStreamResponse(_) => {
// For OpenAI, use simple data line format
let json = serde_json::to_string(&self).unwrap_or_default();
let json = serde_json::to_string(&val).unwrap_or_default();
format!("data: {}\n\n", json)
}
}
}
}
// Stream response transformation logic for client API compatibility
impl TryFrom<(&[u8], &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for ProviderStreamResponseType {
impl TryFrom<(&[u8], &SupportedAPIsFromClient, &SupportedUpstreamAPIs)>
for ProviderStreamResponseType
{
type Error = Box<dyn std::error::Error + Send + Sync>;
fn try_from(
(bytes, client_api, upstream_api): (&[u8], &SupportedAPIsFromClient, &SupportedUpstreamAPIs),
(bytes, client_api, upstream_api): (
&[u8],
&SupportedAPIsFromClient,
&SupportedUpstreamAPIs,
),
) -> Result<Self, Self::Error> {
// Special case: Handle [DONE] marker for OpenAI -> Anthropic conversion
if bytes == b"[DONE]" && matches!(client_api, SupportedAPIsFromClient::AnthropicMessagesAPI(_)) {
if bytes == b"[DONE]"
&& matches!(client_api, SupportedAPIsFromClient::AnthropicMessagesAPI(_))
{
return Ok(ProviderStreamResponseType::MessagesStreamEvent(
crate::apis::anthropic::MessagesStreamEvent::MessageStop,
));
@ -214,9 +228,9 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for Prov
) => {
let openai_resp: crate::apis::openai::ChatCompletionsStreamResponse =
serde_json::from_slice(bytes)?;
let responses_resp = openai_resp.try_into()?;
let responses_resp: ResponsesAPIStreamEvent = openai_resp.try_into()?;
Ok(ProviderStreamResponseType::ResponseAPIStreamEvent(
responses_resp,
Box::new(responses_resp),
))
}
@ -267,10 +281,11 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for Prov
// Chain: Bedrock -> ChatCompletions -> ResponsesAPI
let bedrock_resp: crate::apis::amazon_bedrock::ConverseStreamEvent =
serde_json::from_slice(bytes)?;
let chat_resp: crate::apis::openai::ChatCompletionsStreamResponse = bedrock_resp.try_into()?;
let responses_resp = chat_resp.try_into()?;
let chat_resp: crate::apis::openai::ChatCompletionsStreamResponse =
bedrock_resp.try_into()?;
let responses_resp: ResponsesAPIStreamEvent = chat_resp.try_into()?;
Ok(ProviderStreamResponseType::ResponseAPIStreamEvent(
responses_resp,
Box::new(responses_resp),
))
}
_ => Err(std::io::Error::new(
@ -287,7 +302,11 @@ impl TryFrom<(SseEvent, &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for S
type Error = Box<dyn std::error::Error + Send + Sync>;
fn try_from(
(sse_event, client_api, upstream_api): (SseEvent, &SupportedAPIsFromClient, &SupportedUpstreamAPIs),
(sse_event, client_api, upstream_api): (
SseEvent,
&SupportedAPIsFromClient,
&SupportedUpstreamAPIs,
),
) -> Result<Self, Self::Error> {
// Create a new transformed event based on the original
let mut transformed_event = sse_event;
@ -296,7 +315,11 @@ impl TryFrom<(SseEvent, &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for S
if transformed_event.is_done() {
// For OpenAI client APIs (ChatCompletions and ResponsesAPI), keep [DONE] as-is
// For Anthropic client API, it will be transformed via ProviderStreamResponseType
if matches!(client_api, SupportedAPIsFromClient::OpenAIChatCompletions(_) | SupportedAPIsFromClient::OpenAIResponsesAPI(_)) {
if matches!(
client_api,
SupportedAPIsFromClient::OpenAIChatCompletions(_)
| SupportedAPIsFromClient::OpenAIResponsesAPI(_)
) {
// Keep the [DONE] marker as-is for OpenAI clients
transformed_event.sse_transformed_lines = "data: [DONE]".to_string();
return Ok(transformed_event);
@ -328,7 +351,7 @@ impl TryFrom<(SseEvent, &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for S
// OpenAI clients don't expect separate event: lines
// Suppress upstream Anthropic event-only lines
if transformed_event.is_event_only() && transformed_event.event.is_some() {
transformed_event.sse_transformed_lines = format!("\n");
transformed_event.sse_transformed_lines = "\n".to_string();
}
}
_ => {
@ -345,7 +368,8 @@ impl TryFrom<(SseEvent, &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for S
(
SupportedAPIsFromClient::AnthropicMessagesAPI(_),
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
) | (
)
| (
SupportedAPIsFromClient::OpenAIResponsesAPI(_),
SupportedUpstreamAPIs::OpenAIResponsesAPI(_),
) => {
@ -415,7 +439,7 @@ impl
openai_event,
))
}
(
(
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
SupportedAPIsFromClient::OpenAIResponsesAPI(_),
) => {
@ -428,7 +452,7 @@ impl
openai_chat_completions_event.try_into()?;
Ok(ProviderStreamResponseType::ResponseAPIStreamEvent(
openai_responses_api_event,
Box::new(openai_responses_api_event),
))
}
_ => Err("Unsupported API combination for event-stream decoding".into()),
@ -445,11 +469,11 @@ impl
mod tests {
use super::*;
use crate::apis::streaming_shapes::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder;
use crate::clients::endpoints::SupportedAPIsFromClient;
use crate::apis::streaming_shapes::sse::SseStreamIter;
use crate::clients::endpoints::SupportedAPIsFromClient;
use serde_json::json;
#[test]
#[test]
fn test_sse_event_parsing() {
// Test valid SSE data line
let line = "data: {\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n";
@ -792,7 +816,7 @@ mod tests {
// Simulate chunked network arrivals with realistic chunk sizes
// Using varying chunk sizes to test partial frame handling
let mut buffer = BytesMut::new();
let chunk_size_pattern = vec![500, 1000, 750, 1200, 800, 1500];
let chunk_size_pattern = [500, 1000, 750, 1200, 800, 1500];
let mut offset = 0;
let mut total_frames = 0;
let mut chunk_num = 0;
@ -837,7 +861,7 @@ mod tests {
);
}
#[test]
#[test]
fn test_bedrock_decoded_frame_to_provider_response() {
test_bedrock_conversion(false);
}
@ -879,8 +903,9 @@ mod tests {
let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer);
let client_api =
SupportedAPIsFromClient::AnthropicMessagesAPI(crate::apis::anthropic::AnthropicApi::Messages);
let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(
crate::apis::anthropic::AnthropicApi::Messages,
);
let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream(
crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream,
);
@ -966,8 +991,9 @@ mod tests {
let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer);
let client_api =
SupportedAPIsFromClient::AnthropicMessagesAPI(crate::apis::anthropic::AnthropicApi::Messages);
let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(
crate::apis::anthropic::AnthropicApi::Messages,
);
let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream(
crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream,
);
@ -1051,7 +1077,6 @@ mod tests {
);
}
#[test]
fn test_sse_event_transformation_openai_to_anthropic_message_delta() {
use crate::apis::anthropic::AnthropicApi;
@ -1079,8 +1104,8 @@ mod tests {
let sse_event = SseEvent {
data: Some(openai_stream_chunk.to_string()),
event: None,
raw_line: format!("data: {}", openai_stream_chunk.to_string()),
sse_transformed_lines: format!("data: {}", openai_stream_chunk.to_string()),
raw_line: format!("data: {}", openai_stream_chunk),
sse_transformed_lines: format!("data: {}", openai_stream_chunk),
provider_stream_response: None,
};
@ -1101,7 +1126,8 @@ mod tests {
// Verify the event was transformed to Anthropic format
// This should contain message_delta with stop_reason and usage
assert!(
buffer.contains("event: message_delta") || buffer.contains("\"type\":\"message_delta\""),
buffer.contains("event: message_delta")
|| buffer.contains("\"type\":\"message_delta\""),
"Should contain message_delta in transformed event"
);
@ -1134,8 +1160,8 @@ mod tests {
let sse_event = SseEvent {
data: Some(openai_stream_chunk.to_string()),
event: None,
raw_line: format!("data: {}", openai_stream_chunk.to_string()),
sse_transformed_lines: format!("data: {}", openai_stream_chunk.to_string()),
raw_line: format!("data: {}", openai_stream_chunk),
sse_transformed_lines: format!("data: {}", openai_stream_chunk),
provider_stream_response: None,
};
@ -1223,8 +1249,8 @@ mod tests {
let sse_event = SseEvent {
data: Some(anthropic_event.to_string()),
event: None,
raw_line: format!("data: {}", anthropic_event.to_string()),
sse_transformed_lines: format!("data: {}", anthropic_event.to_string()),
raw_line: format!("data: {}", anthropic_event),
sse_transformed_lines: format!("data: {}", anthropic_event),
provider_stream_response: None,
};
@ -1314,8 +1340,8 @@ mod tests {
let sse_event = SseEvent {
data: Some(openai_stream_chunk.to_string()),
event: None,
raw_line: format!("data: {}", openai_stream_chunk.to_string()),
sse_transformed_lines: format!("data: {}", openai_stream_chunk.to_string()),
raw_line: format!("data: {}", openai_stream_chunk),
sse_transformed_lines: format!("data: {}", openai_stream_chunk),
provider_stream_response: None,
};

View file

@ -11,11 +11,11 @@ pub trait ExtractText {
/// Trait for utility functions on content collections
pub trait ContentUtils<T> {
fn extract_tool_calls(&self) -> Result<Option<Vec<ToolCall>>, TransformError>;
fn split_for_openai(
&self,
) -> Result<(Vec<ContentPart>, Vec<ToolCall>, Vec<(String, String, bool)>), TransformError>;
fn split_for_openai(&self) -> Result<SplitForOpenAIResult, TransformError>;
}
pub type SplitForOpenAIResult = (Vec<ContentPart>, Vec<ToolCall>, Vec<(String, String, bool)>);
/// Helper to create a current unix timestamp
pub fn current_timestamp() -> u64 {
SystemTime::now()

View file

@ -38,7 +38,7 @@ impl TryFrom<AnthropicMessagesRequest> for ChatCompletionsRequest {
}
// Convert tools and tool choice
let openai_tools = req.tools.map(|tools| convert_anthropic_tools(tools));
let openai_tools = req.tools.map(convert_anthropic_tools);
let (openai_tool_choice, parallel_tool_calls) =
convert_anthropic_tool_choice(req.tool_choice);
@ -218,18 +218,18 @@ impl TryFrom<MessagesMessage> for Vec<Message> {
}
// Role Conversions
impl Into<Role> for MessagesRole {
fn into(self) -> Role {
match self {
impl From<MessagesRole> for Role {
fn from(val: MessagesRole) -> Self {
match val {
MessagesRole::User => Role::User,
MessagesRole::Assistant => Role::Assistant,
}
}
}
impl Into<MessagesStopReason> for FinishReason {
fn into(self) -> MessagesStopReason {
match self {
impl From<FinishReason> for MessagesStopReason {
fn from(val: FinishReason) -> Self {
match val {
FinishReason::Stop => MessagesStopReason::EndTurn,
FinishReason::Length => MessagesStopReason::MaxTokens,
FinishReason::ToolCalls => MessagesStopReason::ToolUse,
@ -239,11 +239,11 @@ impl Into<MessagesStopReason> for FinishReason {
}
}
impl Into<MessagesUsage> for Usage {
fn into(self) -> MessagesUsage {
impl From<Usage> for MessagesUsage {
fn from(val: Usage) -> Self {
MessagesUsage {
input_tokens: self.prompt_tokens,
output_tokens: self.completion_tokens,
input_tokens: val.prompt_tokens,
output_tokens: val.completion_tokens,
cache_creation_input_tokens: None,
cache_read_input_tokens: None,
}
@ -251,9 +251,9 @@ impl Into<MessagesUsage> for Usage {
}
// System Prompt Conversions
impl Into<Message> for MessagesSystemPrompt {
fn into(self) -> Message {
let system_content = match self {
impl From<MessagesSystemPrompt> for Message {
fn from(val: MessagesSystemPrompt) -> Self {
let system_content = match val {
MessagesSystemPrompt::Single(text) => MessageContent::Text(text),
MessagesSystemPrompt::Blocks(blocks) => MessageContent::Text(blocks.extract_text()),
};
@ -384,12 +384,8 @@ impl TryFrom<MessagesMessage> for BedrockMessage {
ToolResultContent::Blocks(blocks) => {
let mut result_blocks = Vec::new();
for result_block in blocks {
match result_block {
crate::apis::anthropic::MessagesContentBlock::Text { text, .. } => {
result_blocks.push(ToolResultContentBlock::Text { text });
}
// For now, skip other content types in tool results
_ => {}
if let crate::apis::anthropic::MessagesContentBlock::Text { text, .. } = result_block {
result_blocks.push(ToolResultContentBlock::Text { text });
}
}
result_blocks

View file

@ -14,7 +14,8 @@ use crate::apis::openai::{
};
use crate::apis::openai_responses::{
ResponsesAPIRequest, InputContent, InputItem, InputParam, MessageRole, Modality, ReasoningEffort, Tool as ResponsesTool, ToolChoice as ResponsesToolChoice
InputContent, InputItem, InputParam, MessageRole, Modality, ReasoningEffort,
ResponsesAPIRequest, Tool as ResponsesTool, ToolChoice as ResponsesToolChoice,
};
use crate::clients::TransformError;
use crate::transforms::lib::ExtractText;
@ -27,9 +28,9 @@ type AnthropicMessagesRequest = MessagesRequest;
// MAIN REQUEST TRANSFORMATIONS
// ============================================================================
impl Into<MessagesSystemPrompt> for Message {
fn into(self) -> MessagesSystemPrompt {
let system_text = match self.content {
impl From<Message> for MessagesSystemPrompt {
fn from(val: Message) -> Self {
let system_text = match val.content {
MessageContent::Text(text) => text,
MessageContent::Parts(parts) => parts.extract_text(),
};
@ -163,7 +164,7 @@ impl TryFrom<Message> for BedrockMessage {
let has_tool_calls = message
.tool_calls
.as_ref()
.map_or(false, |calls| !calls.is_empty());
.is_some_and(|calls| !calls.is_empty());
// Add text content if it's non-empty, or if we have no tool calls (to avoid empty content)
if !text_content.is_empty() {
@ -252,7 +253,6 @@ impl TryFrom<ResponsesAPIRequest> for ChatCompletionsRequest {
type Error = TransformError;
fn try_from(req: ResponsesAPIRequest) -> Result<Self, Self::Error> {
// Convert input to messages
let messages = match req.input {
InputParam::Text(text) => {
@ -282,50 +282,27 @@ impl TryFrom<ResponsesAPIRequest> for ChatCompletionsRequest {
// Convert each input item
for item in items {
match item {
InputItem::Message(input_msg) => {
let role = match input_msg.role {
MessageRole::User => Role::User,
MessageRole::Assistant => Role::Assistant,
MessageRole::System => Role::System,
MessageRole::Developer => Role::System, // Map developer to system
};
if let InputItem::Message(input_msg) = item {
let role = match input_msg.role {
MessageRole::User => Role::User,
MessageRole::Assistant => Role::Assistant,
MessageRole::System => Role::System,
MessageRole::Developer => Role::System, // Map developer to system
};
// Convert content based on MessageContent type
let content = match &input_msg.content {
crate::apis::openai_responses::MessageContent::Text(text) => {
// Simple text content
MessageContent::Text(text.clone())
}
crate::apis::openai_responses::MessageContent::Items(content_items) => {
// Check if it's a single text item (can use simple text format)
if content_items.len() == 1 {
if let InputContent::InputText { text } = &content_items[0] {
MessageContent::Text(text.clone())
} else {
// Single non-text item - use parts format
MessageContent::Parts(
content_items.iter()
.filter_map(|c| match c {
InputContent::InputText { text } => {
Some(crate::apis::openai::ContentPart::Text { text: text.clone() })
}
InputContent::InputImage { image_url, .. } => {
Some(crate::apis::openai::ContentPart::ImageUrl {
image_url: crate::apis::openai::ImageUrl {
url: image_url.clone(),
detail: None,
}
})
}
InputContent::InputFile { .. } => None, // Skip files for now
InputContent::InputAudio { .. } => None, // Skip audio for now
})
.collect()
)
}
// Convert content based on MessageContent type
let content = match &input_msg.content {
crate::apis::openai_responses::MessageContent::Text(text) => {
// Simple text content
MessageContent::Text(text.clone())
}
crate::apis::openai_responses::MessageContent::Items(content_items) => {
// Check if it's a single text item (can use simple text format)
if content_items.len() == 1 {
if let InputContent::InputText { text } = &content_items[0] {
MessageContent::Text(text.clone())
} else {
// Multiple content items - convert to parts
// Single non-text item - use parts format
MessageContent::Parts(
content_items.iter()
.filter_map(|c| match c {
@ -346,20 +323,41 @@ impl TryFrom<ResponsesAPIRequest> for ChatCompletionsRequest {
.collect()
)
}
} else {
// Multiple content items - convert to parts
MessageContent::Parts(
content_items
.iter()
.filter_map(|c| match c {
InputContent::InputText { text } => {
Some(crate::apis::openai::ContentPart::Text {
text: text.clone(),
})
}
InputContent::InputImage { image_url, .. } => Some(
crate::apis::openai::ContentPart::ImageUrl {
image_url: crate::apis::openai::ImageUrl {
url: image_url.clone(),
detail: None,
},
},
),
InputContent::InputFile { .. } => None, // Skip files for now
InputContent::InputAudio { .. } => None, // Skip audio for now
})
.collect(),
)
}
};
}
};
converted_messages.push(Message {
role,
content,
name: None,
tool_call_id: None,
tool_calls: None,
});
}
// Skip non-message items (references, outputs) for now
// These would need special handling in chat completions format
_ => {}
converted_messages.push(Message {
role,
content,
name: None,
tool_call_id: None,
tool_calls: None,
});
}
}
@ -474,7 +472,7 @@ impl TryFrom<ChatCompletionsRequest> for AnthropicMessagesRequest {
}
// Convert tools and tool choice
let anthropic_tools = req.tools.map(|tools| convert_openai_tools(tools));
let anthropic_tools = req.tools.map(convert_openai_tools);
let anthropic_tool_choice =
convert_openai_tool_choice(req.tool_choice, req.parallel_tool_calls);

View file

@ -13,18 +13,14 @@ use crate::apis::openai_responses::{
pub fn convert_responses_output_to_input_items(output: &OutputItem) -> Option<InputItem> {
match output {
// Convert output messages to input messages
OutputItem::Message {
role, content, ..
} => {
OutputItem::Message { role, content, .. } => {
let input_content: Vec<InputContent> = content
.iter()
.filter_map(|c| match c {
OutputContent::OutputText { text, .. } => Some(InputContent::InputText {
text: text.clone(),
}),
OutputContent::OutputAudio {
data, ..
} => Some(InputContent::InputAudio {
OutputContent::OutputText { text, .. } => {
Some(InputContent::InputText { text: text.clone() })
}
OutputContent::OutputAudio { data, .. } => Some(InputContent::InputAudio {
data: data.clone(),
format: None, // Format not preserved in output
}),
@ -84,7 +80,7 @@ pub fn outputs_to_inputs(outputs: &[OutputItem]) -> Vec<InputItem> {
#[cfg(test)]
mod tests {
use super::*;
use crate::apis::openai_responses::{OutputItemStatus};
use crate::apis::openai_responses::OutputItemStatus;
#[test]
fn test_output_message_to_input() {
@ -135,14 +131,12 @@ mod tests {
InputItem::Message(msg) => {
assert!(matches!(msg.role, MessageRole::Assistant));
match &msg.content {
MessageContent::Items(items) => {
match &items[0] {
InputContent::InputText { text } => {
assert!(text.contains("get_weather"));
}
_ => panic!("Expected InputText"),
MessageContent::Items(items) => match &items[0] {
InputContent::InputText { text } => {
assert!(text.contains("get_weather"));
}
}
_ => panic!("Expected InputText"),
},
_ => panic!("Expected MessageContent::Items"),
}
}

View file

@ -1,7 +1,6 @@
use crate::apis::amazon_bedrock::{ConverseOutput, ConverseResponse, StopReason};
use crate::apis::anthropic::{
MessagesContentBlock, MessagesResponse,
MessagesRole, MessagesStopReason, MessagesUsage,
MessagesContentBlock, MessagesResponse, MessagesRole, MessagesStopReason, MessagesUsage,
};
use crate::apis::openai::ChatCompletionsResponse;
use crate::clients::TransformError;
@ -115,7 +114,6 @@ impl TryFrom<ConverseResponse> for MessagesResponse {
}
}
/// Convert Bedrock Message to Anthropic content blocks
///
/// This function handles the conversion between Amazon Bedrock Converse API format

View file

@ -1,9 +1,5 @@
use crate::apis::amazon_bedrock::{
ConverseOutput, ConverseResponse, StopReason,
};
use crate::apis::anthropic::{
MessagesContentBlock, MessagesResponse, MessagesUsage,
};
use crate::apis::amazon_bedrock::{ConverseOutput, ConverseResponse, StopReason};
use crate::apis::anthropic::{MessagesContentBlock, MessagesResponse, MessagesUsage};
use crate::apis::openai::{
ChatCompletionsResponse, Choice, FinishReason, MessageContent, ResponseMessage, Role, Usage,
};
@ -16,12 +12,12 @@ use crate::transforms::lib::*;
// ============================================================================
// Usage Conversions
impl Into<Usage> for MessagesUsage {
fn into(self) -> Usage {
impl From<MessagesUsage> for Usage {
fn from(val: MessagesUsage) -> Self {
Usage {
prompt_tokens: self.input_tokens,
completion_tokens: self.output_tokens,
total_tokens: self.input_tokens + self.output_tokens,
prompt_tokens: val.input_tokens,
completion_tokens: val.output_tokens,
total_tokens: val.input_tokens + val.output_tokens,
prompt_tokens_details: None,
completion_tokens_details: None,
}
@ -203,7 +199,6 @@ impl TryFrom<ChatCompletionsResponse> for ResponsesAPIResponse {
}
}
impl TryFrom<MessagesResponse> for ChatCompletionsResponse {
type Error = TransformError;
@ -415,7 +410,6 @@ fn convert_anthropic_content_to_openai(
Ok(MessageContent::Text(text_parts.join("\n")))
}
#[cfg(test)]
mod tests {
use super::*;
@ -994,8 +988,15 @@ mod tests {
let responses_api: ResponsesAPIResponse = chat_response.try_into().unwrap();
// Response ID should be generated with resp_ prefix
assert!(responses_api.id.starts_with("resp_"), "Response ID should start with 'resp_'");
assert_eq!(responses_api.id.len(), 37, "Response ID should be resp_ + 32 char UUID");
assert!(
responses_api.id.starts_with("resp_"),
"Response ID should start with 'resp_'"
);
assert_eq!(
responses_api.id.len(),
37,
"Response ID should be resp_ + 32 char UUID"
);
assert_eq!(responses_api.object, "response");
assert_eq!(responses_api.model, "gpt-4");
@ -1008,11 +1009,7 @@ mod tests {
// Check output items
assert_eq!(responses_api.output.len(), 1);
match &responses_api.output[0] {
OutputItem::Message {
role,
content,
..
} => {
OutputItem::Message { role, content, .. } => {
assert_eq!(role, "assistant");
assert_eq!(content.len(), 1);
match &content[0] {
@ -1163,6 +1160,9 @@ mod tests {
}
// Verify status is Completed for tool_calls finish reason
assert!(matches!(responses_api.status, crate::apis::openai_responses::ResponseStatus::Completed));
assert!(matches!(
responses_api.status,
crate::apis::openai_responses::ResponseStatus::Completed
));
}
}

View file

@ -1,12 +1,9 @@
use crate::apis::amazon_bedrock::{
ContentBlockDelta, ConverseStreamEvent,
};
use crate::apis::amazon_bedrock::{ContentBlockDelta, ConverseStreamEvent};
use crate::apis::anthropic::{
MessagesContentBlock, MessagesContentDelta, MessagesMessageDelta,
MessagesRole, MessagesStopReason, MessagesStreamEvent, MessagesStreamMessage, MessagesUsage,
};
use crate::apis::openai::{ ChatCompletionsStreamResponse, ToolCallDelta,
MessagesContentBlock, MessagesContentDelta, MessagesMessageDelta, MessagesRole,
MessagesStopReason, MessagesStreamEvent, MessagesStreamMessage, MessagesUsage,
};
use crate::apis::openai::{ChatCompletionsStreamResponse, ToolCallDelta};
use crate::clients::TransformError;
use serde_json::Value;
@ -86,10 +83,10 @@ impl TryFrom<ChatCompletionsStreamResponse> for MessagesStreamEvent {
}
}
impl Into<String> for MessagesStreamEvent {
fn into(self) -> String {
let transformed_json = serde_json::to_string(&self).unwrap_or_default();
let event_type = match &self {
impl From<MessagesStreamEvent> for String {
fn from(val: MessagesStreamEvent) -> Self {
let transformed_json = serde_json::to_string(&val).unwrap_or_default();
let event_type = match &val {
MessagesStreamEvent::MessageStart { .. } => "message_start",
MessagesStreamEvent::ContentBlockStart { .. } => "content_block_start",
MessagesStreamEvent::ContentBlockDelta { .. } => "content_block_delta",
@ -194,10 +191,18 @@ impl TryFrom<ConverseStreamEvent> for MessagesStreamEvent {
let anthropic_stop_reason = match stop_event.stop_reason {
crate::apis::amazon_bedrock::StopReason::EndTurn => MessagesStopReason::EndTurn,
crate::apis::amazon_bedrock::StopReason::ToolUse => MessagesStopReason::ToolUse,
crate::apis::amazon_bedrock::StopReason::MaxTokens => MessagesStopReason::MaxTokens,
crate::apis::amazon_bedrock::StopReason::StopSequence => MessagesStopReason::EndTurn,
crate::apis::amazon_bedrock::StopReason::GuardrailIntervened => MessagesStopReason::Refusal,
crate::apis::amazon_bedrock::StopReason::ContentFiltered => MessagesStopReason::Refusal,
crate::apis::amazon_bedrock::StopReason::MaxTokens => {
MessagesStopReason::MaxTokens
}
crate::apis::amazon_bedrock::StopReason::StopSequence => {
MessagesStopReason::EndTurn
}
crate::apis::amazon_bedrock::StopReason::GuardrailIntervened => {
MessagesStopReason::Refusal
}
crate::apis::amazon_bedrock::StopReason::ContentFiltered => {
MessagesStopReason::Refusal
}
};
Ok(MessagesStreamEvent::MessageDelta {

View file

@ -1,8 +1,10 @@
use crate::apis::amazon_bedrock::{ ConverseStreamEvent, StopReason};
use crate::apis::amazon_bedrock::{ConverseStreamEvent, StopReason};
use crate::apis::anthropic::{
MessagesContentBlock, MessagesContentDelta, MessagesStopReason, MessagesStreamEvent};
use crate::apis::openai::{ ChatCompletionsStreamResponse,FinishReason,
FunctionCallDelta, MessageDelta, Role, StreamChoice, ToolCallDelta, Usage,
MessagesContentBlock, MessagesContentDelta, MessagesStopReason, MessagesStreamEvent,
};
use crate::apis::openai::{
ChatCompletionsStreamResponse, FinishReason, FunctionCallDelta, MessageDelta, Role,
StreamChoice, ToolCallDelta, Usage,
};
use crate::apis::openai_responses::ResponsesAPIStreamEvent;
@ -58,11 +60,14 @@ impl TryFrom<MessagesStreamEvent> for ChatCompletionsStreamResponse {
None,
)),
MessagesStreamEvent::ContentBlockStart { content_block, index } => {
convert_content_block_start(content_block, index)
}
MessagesStreamEvent::ContentBlockStart {
content_block,
index,
} => convert_content_block_start(content_block, index),
MessagesStreamEvent::ContentBlockDelta { delta, index } => convert_content_delta(delta, index),
MessagesStreamEvent::ContentBlockDelta { delta, index } => {
convert_content_delta(delta, index)
}
MessagesStreamEvent::ContentBlockStop { .. } => Ok(create_empty_openai_chunk()),
@ -427,9 +432,9 @@ fn create_empty_openai_chunk() -> ChatCompletionsStreamResponse {
}
// Stop Reason Conversions
impl Into<FinishReason> for MessagesStopReason {
fn into(self) -> FinishReason {
match self {
impl From<MessagesStopReason> for FinishReason {
fn from(val: MessagesStopReason) -> Self {
match val {
MessagesStopReason::EndTurn => FinishReason::Stop,
MessagesStopReason::MaxTokens => FinishReason::Length,
MessagesStopReason::StopSequence => FinishReason::Stop,
@ -456,34 +461,37 @@ impl TryFrom<ChatCompletionsStreamResponse> for ResponsesAPIStreamEvent {
if let Some(tool_call) = tool_calls.first() {
// Extract call_id and name if available (metadata from initial event)
let call_id = tool_call.id.clone();
let function_name = tool_call.function.as_ref()
.and_then(|f| f.name.clone());
let function_name = tool_call.function.as_ref().and_then(|f| f.name.clone());
// Check if we have function metadata (name, id)
if let Some(function) = &tool_call.function {
// If we have arguments delta, return that
if let Some(args) = &function.arguments {
return Ok(ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta {
output_index: choice.index as i32,
item_id: "".to_string(), // Buffer will fill this
delta: args.clone(),
sequence_number: 0, // Buffer will fill this
call_id,
name: function_name,
});
return Ok(
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta {
output_index: choice.index as i32,
item_id: "".to_string(), // Buffer will fill this
delta: args.clone(),
sequence_number: 0, // Buffer will fill this
call_id,
name: function_name,
},
);
}
// If we have function name but no arguments yet (initial tool call event)
// Return an empty arguments delta so the buffer knows to create the item
if function.name.is_some() {
return Ok(ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta {
output_index: choice.index as i32,
item_id: "".to_string(), // Buffer will fill this
delta: "".to_string(), // Empty delta signals this is the initial event
sequence_number: 0, // Buffer will fill this
call_id,
name: function_name,
});
return Ok(
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta {
output_index: choice.index as i32,
item_id: "".to_string(), // Buffer will fill this
delta: "".to_string(), // Empty delta signals this is the initial event
sequence_number: 0, // Buffer will fill this
call_id,
name: function_name,
},
);
}
}
}

View file

@ -94,8 +94,8 @@ impl StreamContext {
fn request_identifier(&self) -> String {
self.request_id
.as_ref()
.filter(|id| !id.is_empty()) // Filter out empty strings
.map(|id| id.clone())
.filter(|id| !id.is_empty())
.cloned()
.unwrap_or_else(|| "NO_REQUEST_ID".to_string())
}
fn llm_provider(&self) -> &LlmProvider {
@ -504,7 +504,7 @@ impl StreamContext {
// Get accumulated bytes from buffer and return
match self.sse_buffer.as_mut() {
Some(buffer) => {
let bytes = buffer.into_bytes();
let bytes = buffer.to_bytes();
if !bytes.is_empty() {
let content = String::from_utf8_lossy(&bytes);
debug!(
@ -623,7 +623,7 @@ impl StreamContext {
// Get accumulated bytes from buffer and return
match self.sse_buffer.as_mut() {
Some(buffer) => {
let bytes = buffer.into_bytes();
let bytes = buffer.to_bytes();
if !bytes.is_empty() {
let content = String::from_utf8_lossy(&bytes);
debug!(

View file

@ -142,8 +142,7 @@ impl HttpContext for StreamContext {
let last_user_prompt = match deserialized_body
.messages
.iter()
.filter(|msg| msg.role == USER_ROLE)
.last()
.rfind(|msg| msg.role == USER_ROLE)
{
Some(content) => content,
None => {
@ -155,11 +154,8 @@ impl HttpContext for StreamContext {
self.user_prompt = Some(last_user_prompt.clone());
// convert prompt targets to ChatCompletionTool
let tool_calls: Vec<ChatCompletionTool> = self
.prompt_targets
.iter()
.map(|(_, pt)| pt.into())
.collect();
let tool_calls: Vec<ChatCompletionTool> =
self.prompt_targets.values().map(|pt| pt.into()).collect();
let mut metadata = deserialized_body.metadata.clone();

View file

@ -376,21 +376,22 @@ impl StreamContext {
// Parse arguments JSON string into HashMap
// Note: convert from serde_json::Value to serde_yaml::Value for compatibility
let tool_params: Option<HashMap<String, serde_yaml::Value>> = match serde_json::from_str::<HashMap<String, serde_json::Value>>(tool_params_str) {
Ok(json_params) => {
let yaml_params: HashMap<String, serde_yaml::Value> = json_params
.into_iter()
.filter_map(|(k, v)| {
serde_yaml::to_value(&v).ok().map(|yaml_v| (k, yaml_v))
})
.collect();
Some(yaml_params)
},
Err(e) => {
warn!("Failed to parse tool call arguments: {}", e);
None
}
};
let tool_params: Option<HashMap<String, serde_yaml::Value>> =
match serde_json::from_str::<HashMap<String, serde_json::Value>>(tool_params_str) {
Ok(json_params) => {
let yaml_params: HashMap<String, serde_yaml::Value> = json_params
.into_iter()
.filter_map(|(k, v)| {
serde_yaml::to_value(&v).ok().map(|yaml_v| (k, yaml_v))
})
.collect();
Some(yaml_params)
}
Err(e) => {
warn!("Failed to parse tool call arguments: {}", e);
None
}
};
let endpoint_details = prompt_target.endpoint.as_ref().unwrap();
let endpoint_path: String = endpoint_details
@ -629,10 +630,10 @@ impl StreamContext {
}
};
if system_prompt.is_some() {
if let Some(system_prompt_text) = system_prompt {
let system_prompt_message = Message {
role: SYSTEM_ROLE.to_string(),
content: Some(ContentType::Text(system_prompt.unwrap())),
content: Some(ContentType::Text(system_prompt_text)),
model: None,
tool_calls: None,
tool_call_id: None,