cargo clippy (#660)

This commit is contained in:
Adil Hafeez 2025-12-25 21:08:37 -08:00 committed by GitHub
parent c75e7606f9
commit ca95ffb63d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
62 changed files with 1864 additions and 1187 deletions

View file

@ -13,19 +13,22 @@ repos:
name: cargo-fmt name: cargo-fmt
language: system language: system
types: [file, rust] types: [file, rust]
entry: bash -c "cd crates/llm_gateway && cargo fmt" entry: bash -c "cd crates && cargo fmt --all -- --check"
pass_filenames: false
- id: cargo-clippy - id: cargo-clippy
name: cargo-clippy name: cargo-clippy
language: system language: system
types: [file, rust] types: [file, rust]
entry: bash -c "cd crates/llm_gateway && cargo clippy --all" entry: bash -c "cd crates && cargo clippy --locked --offline --all-targets --all-features -- -D warnings || cargo clippy --locked --all-targets --all-features -- -D warnings"
pass_filenames: false
- id: cargo-test - id: cargo-test
name: cargo-test name: cargo-test
language: system language: system
types: [file, rust] types: [file, rust]
entry: bash -c "cd crates && cargo test --lib" entry: bash -c "cd crates && cargo test --lib"
pass_filenames: false
- repo: https://github.com/psf/black - repo: https://github.com/psf/black
rev: 23.1.0 rev: 23.1.0

View file

@ -3,7 +3,7 @@ use std::time::{Instant, SystemTime};
use bytes::Bytes; use bytes::Bytes;
use common::consts::TRACE_PARENT_HEADER; use common::consts::TRACE_PARENT_HEADER;
use common::traces::{SpanBuilder, SpanKind, parse_traceparent, generate_random_span_id}; use common::traces::{generate_random_span_id, parse_traceparent, SpanBuilder, SpanKind};
use hermesllm::apis::OpenAIMessage; use hermesllm::apis::OpenAIMessage;
use hermesllm::clients::SupportedAPIsFromClient; use hermesllm::clients::SupportedAPIsFromClient;
use hermesllm::providers::request::ProviderRequest; use hermesllm::providers::request::ProviderRequest;
@ -18,7 +18,7 @@ use super::agent_selector::{AgentSelectionError, AgentSelector};
use super::pipeline_processor::{PipelineError, PipelineProcessor}; use super::pipeline_processor::{PipelineError, PipelineProcessor};
use super::response_handler::ResponseHandler; use super::response_handler::ResponseHandler;
use crate::router::plano_orchestrator::OrchestratorService; use crate::router::plano_orchestrator::OrchestratorService;
use crate::tracing::{OperationNameBuilder, operation_component, http}; use crate::tracing::{http, operation_component, OperationNameBuilder};
/// Main errors for agent chat completions /// Main errors for agent chat completions
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
@ -61,7 +61,6 @@ pub async fn agent_chat(
body, body,
}) = &err }) = &err
{ {
warn!( warn!(
"Client error from agent '{}' (HTTP {}): {}", "Client error from agent '{}' (HTTP {}): {}",
agent, status, body agent, status, body
@ -77,8 +76,8 @@ pub async fn agent_chat(
let json_string = error_json.to_string(); let json_string = error_json.to_string();
let mut response = Response::new(ResponseHandler::create_full_body(json_string)); let mut response = Response::new(ResponseHandler::create_full_body(json_string));
*response.status_mut() = hyper::StatusCode::from_u16(*status) *response.status_mut() =
.unwrap_or(hyper::StatusCode::BAD_REQUEST); hyper::StatusCode::from_u16(*status).unwrap_or(hyper::StatusCode::BAD_REQUEST);
response.headers_mut().insert( response.headers_mut().insert(
hyper::header::CONTENT_TYPE, hyper::header::CONTENT_TYPE,
"application/json".parse().unwrap(), "application/json".parse().unwrap(),
@ -234,8 +233,18 @@ async fn handle_agent_chat(
.with_attribute(http::TARGET, "/agents/select") .with_attribute(http::TARGET, "/agents/select")
.with_attribute("selection.listener", listener.name.clone()) .with_attribute("selection.listener", listener.name.clone())
.with_attribute("selection.agent_count", selected_agents.len().to_string()) .with_attribute("selection.agent_count", selected_agents.len().to_string())
.with_attribute("selection.agents", selected_agents.iter().map(|a| a.id.as_str()).collect::<Vec<_>>().join(",")) .with_attribute(
.with_attribute("duration_ms", format!("{:.2}", selection_elapsed.as_secs_f64() * 1000.0)); "selection.agents",
selected_agents
.iter()
.map(|a| a.id.as_str())
.collect::<Vec<_>>()
.join(","),
)
.with_attribute(
"duration_ms",
format!("{:.2}", selection_elapsed.as_secs_f64() * 1000.0),
);
if !trace_id.is_empty() { if !trace_id.is_empty() {
selection_span_builder = selection_span_builder.with_trace_id(trace_id.clone()); selection_span_builder = selection_span_builder.with_trace_id(trace_id.clone());
@ -318,8 +327,14 @@ async fn handle_agent_chat(
.with_attribute(http::METHOD, "POST") .with_attribute(http::METHOD, "POST")
.with_attribute(http::TARGET, full_path) .with_attribute(http::TARGET, full_path)
.with_attribute("agent.name", agent_name.clone()) .with_attribute("agent.name", agent_name.clone())
.with_attribute("agent.sequence", format!("{}/{}", agent_index + 1, agent_count)) .with_attribute(
.with_attribute("duration_ms", format!("{:.2}", agent_elapsed.as_secs_f64() * 1000.0)); "agent.sequence",
format!("{}/{}", agent_index + 1, agent_count),
)
.with_attribute(
"duration_ms",
format!("{:.2}", agent_elapsed.as_secs_f64() * 1000.0),
);
if !trace_id.is_empty() { if !trace_id.is_empty() {
span_builder = span_builder.with_trace_id(trace_id.clone()); span_builder = span_builder.with_trace_id(trace_id.clone());
@ -333,7 +348,10 @@ async fn handle_agent_chat(
// If this is the last agent, return the streaming response // If this is the last agent, return the streaming response
if is_last_agent { if is_last_agent {
info!("Completed agent chain, returning response from last agent: {}", agent_name); info!(
"Completed agent chain, returning response from last agent: {}",
agent_name
);
return response_handler return response_handler
.create_streaming_response(llm_response) .create_streaming_response(llm_response)
.await .await
@ -341,7 +359,10 @@ async fn handle_agent_chat(
} }
// For intermediate agents, collect the full response and pass to next agent // For intermediate agents, collect the full response and pass to next agent
debug!("Collecting response from intermediate agent: {}", agent_name); debug!(
"Collecting response from intermediate agent: {}",
agent_name
);
let response_text = response_handler.collect_full_response(llm_response).await?; let response_text = response_handler.collect_full_response(llm_response).await?;
info!( info!(
@ -364,7 +385,6 @@ async fn handle_agent_chat(
}); });
current_messages.push(last_message); current_messages.push(last_message);
} }
// This should never be reached since we return in the last agent iteration // This should never be reached since we return in the last agent iteration

View file

@ -2,7 +2,7 @@ use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use common::configuration::{ use common::configuration::{
Agent, AgentFilterChain, Listener, AgentUsagePreference, OrchestrationPreference, Agent, AgentFilterChain, AgentUsagePreference, Listener, OrchestrationPreference,
}; };
use hermesllm::apis::openai::Message; use hermesllm::apis::openai::Message;
use tracing::{debug, warn}; use tracing::{debug, warn};

View file

@ -1,20 +1,18 @@
use bytes::Bytes;
use eventsource_stream::Eventsource;
use futures::StreamExt;
use hermesllm::apis::openai::{ use hermesllm::apis::openai::{
ChatCompletionsRequest, ChatCompletionsResponse, Choice, FinishReason, FunctionCall, Message, ChatCompletionsRequest, ChatCompletionsResponse, Choice, FinishReason, FunctionCall, Message,
MessageContent, ResponseMessage, Role, Tool, ToolCall, Usage, MessageContent, ResponseMessage, Role, Tool, ToolCall, Usage,
}; };
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use hyper::body::Incoming;
use hyper::{Request, Response, StatusCode};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{json, Value}; use serde_json::{json, Value};
use std::collections::HashMap; use std::collections::HashMap;
use thiserror::Error; use thiserror::Error;
use tracing::{info, error}; use tracing::{error, info};
use futures::StreamExt;
use bytes::Bytes;
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use hyper::body::Incoming;
use hyper::{Request, Response, StatusCode};
use eventsource_stream::Eventsource;
// ============================================================================ // ============================================================================
// CONSTANTS FOR HALLUCINATION DETECTION // CONSTANTS FOR HALLUCINATION DETECTION
@ -273,17 +271,14 @@ impl ArchFunctionHandler {
let mut stack: Vec<char> = Vec::new(); let mut stack: Vec<char> = Vec::new();
let mut fixed_str = String::new(); let mut fixed_str = String::new();
let matching_bracket: HashMap<char, char> = let matching_bracket: HashMap<char, char> = [(')', '('), ('}', '{'), (']', '[')]
[(')', '('), ('}', '{'), (']', '[')]
.iter()
.cloned()
.collect();
let opening_bracket: HashMap<char, char> = matching_bracket
.iter() .iter()
.map(|(k, v)| (*v, *k)) .cloned()
.collect(); .collect();
let opening_bracket: HashMap<char, char> =
matching_bracket.iter().map(|(k, v)| (*v, *k)).collect();
for ch in json_str.chars() { for ch in json_str.chars() {
if ch == '{' || ch == '[' || ch == '(' { if ch == '{' || ch == '[' || ch == '(' {
stack.push(ch); stack.push(ch);
@ -332,12 +327,18 @@ impl ArchFunctionHandler {
// Remove markdown code blocks // Remove markdown code blocks
let mut content = content.trim().to_string(); let mut content = content.trim().to_string();
if content.starts_with("```") && content.ends_with("```") { if content.starts_with("```") && content.ends_with("```") {
content = content.trim_start_matches("```").trim_end_matches("```").to_string(); content = content
.trim_start_matches("```")
.trim_end_matches("```")
.to_string();
if content.starts_with("json") { if content.starts_with("json") {
content = content.trim_start_matches("json").to_string(); content = content.trim_start_matches("json").to_string();
} }
// Trim again after removing code blocks to eliminate internal whitespace // Trim again after removing code blocks to eliminate internal whitespace
content = content.trim_start_matches(r"\n").trim_end_matches(r"\n").to_string(); content = content
.trim_start_matches(r"\n")
.trim_end_matches(r"\n")
.to_string();
content = content.trim().to_string(); content = content.trim().to_string();
// Unescape the quotes: \" -> " // Unescape the quotes: \" -> "
// The model sometimes returns escaped JSON inside markdown blocks // The model sometimes returns escaped JSON inside markdown blocks
@ -453,12 +454,12 @@ impl ArchFunctionHandler {
/// Helper method to check if a value matches the expected type /// Helper method to check if a value matches the expected type
fn check_value_type(&self, value: &Value, target_type: &str) -> bool { fn check_value_type(&self, value: &Value, target_type: &str) -> bool {
match target_type { match target_type {
"int" | "integer" => value.is_i64() || value.is_u64(), "int" | "integer" => value.is_i64() || value.is_u64(),
"float" | "number" => value.is_f64() || value.is_i64() || value.is_u64(), "float" | "number" => value.is_f64() || value.is_i64() || value.is_u64(),
"bool" | "boolean" => value.is_boolean(), "bool" | "boolean" => value.is_boolean(),
"str" | "string" => value.is_string(), "str" | "string" => value.is_string(),
"list" | "array" => value.is_array(), "list" | "array" => value.is_array(),
"dict" | "object" => value.is_object(), "dict" | "object" => value.is_object(),
_ => true, _ => true,
} }
} }
@ -505,15 +506,19 @@ impl ArchFunctionHandler {
let func_name = &tool_call.function.name; let func_name = &tool_call.function.name;
// Parse arguments as JSON // Parse arguments as JSON
let func_args: HashMap<String, Value> = match serde_json::from_str(&tool_call.function.arguments) { let func_args: HashMap<String, Value> =
Ok(args) => args, match serde_json::from_str(&tool_call.function.arguments) {
Err(e) => { Ok(args) => args,
verification.is_valid = false; Err(e) => {
verification.invalid_tool_call = Some(tool_call.clone()); verification.is_valid = false;
verification.error_message = format!("Failed to parse arguments for function '{}': {}", func_name, e); verification.invalid_tool_call = Some(tool_call.clone());
break; verification.error_message = format!(
} "Failed to parse arguments for function '{}': {}",
}; func_name, e
);
break;
}
};
// Check if function is available // Check if function is available
if let Some(function_params) = functions.get(func_name) { if let Some(function_params) = functions.get(func_name) {
@ -541,14 +546,23 @@ impl ArchFunctionHandler {
if let Some(properties_obj) = properties.as_object() { if let Some(properties_obj) = properties.as_object() {
for (param_name, param_value) in &func_args { for (param_name, param_value) in &func_args {
if let Some(param_schema) = properties_obj.get(param_name) { if let Some(param_schema) = properties_obj.get(param_name) {
if let Some(target_type) = param_schema.get("type").and_then(|v| v.as_str()) { if let Some(target_type) =
if self.config.support_data_types.contains(&target_type.to_string()) { param_schema.get("type").and_then(|v| v.as_str())
{
if self
.config
.support_data_types
.contains(&target_type.to_string())
{
// Validate data type using helper method // Validate data type using helper method
match self.validate_or_convert_parameter(param_value, target_type) { match self
.validate_or_convert_parameter(param_value, target_type)
{
Ok(is_valid) => { Ok(is_valid) => {
if !is_valid { if !is_valid {
verification.is_valid = false; verification.is_valid = false;
verification.invalid_tool_call = Some(tool_call.clone()); verification.invalid_tool_call =
Some(tool_call.clone());
verification.error_message = format!( verification.error_message = format!(
"Parameter `{}` is expected to have the data type `{}`, got incompatible type.", "Parameter `{}` is expected to have the data type `{}`, got incompatible type.",
param_name, target_type param_name, target_type
@ -558,7 +572,8 @@ impl ArchFunctionHandler {
} }
Err(_) => { Err(_) => {
verification.is_valid = false; verification.is_valid = false;
verification.invalid_tool_call = Some(tool_call.clone()); verification.invalid_tool_call =
Some(tool_call.clone());
verification.error_message = format!( verification.error_message = format!(
"Parameter `{}` is expected to have the data type `{}`, got incompatible type.", "Parameter `{}` is expected to have the data type `{}`, got incompatible type.",
param_name, target_type param_name, target_type
@ -569,7 +584,10 @@ impl ArchFunctionHandler {
} else { } else {
verification.is_valid = false; verification.is_valid = false;
verification.invalid_tool_call = Some(tool_call.clone()); verification.invalid_tool_call = Some(tool_call.clone());
verification.error_message = format!("Data type `{}` is not supported.", target_type); verification.error_message = format!(
"Data type `{}` is not supported.",
target_type
);
break; break;
} }
} }
@ -598,11 +616,8 @@ impl ArchFunctionHandler {
/// Formats the system prompt with tools /// Formats the system prompt with tools
pub fn format_system_prompt(&self, tools: &[Tool]) -> Result<String> { pub fn format_system_prompt(&self, tools: &[Tool]) -> Result<String> {
let tools_str = self.convert_tools(tools)?; let tools_str = self.convert_tools(tools)?;
let system_prompt = self let system_prompt =
.config self.config.task_prompt.replace("{tools}", &tools_str) + &self.config.format_prompt;
.task_prompt
.replace("{tools}", &tools_str)
+ &self.config.format_prompt;
Ok(system_prompt) Ok(system_prompt)
} }
@ -665,15 +680,22 @@ impl ArchFunctionHandler {
// Strip markdown code blocks // Strip markdown code blocks
if tool_call_msg.starts_with("```") && tool_call_msg.ends_with("```") { if tool_call_msg.starts_with("```") && tool_call_msg.ends_with("```") {
tool_call_msg = tool_call_msg.trim_start_matches("```").trim_end_matches("```").trim().to_string(); tool_call_msg = tool_call_msg
.trim_start_matches("```")
.trim_end_matches("```")
.trim()
.to_string();
if tool_call_msg.starts_with("json") { if tool_call_msg.starts_with("json") {
tool_call_msg = tool_call_msg.trim_start_matches("json").trim().to_string(); tool_call_msg =
tool_call_msg.trim_start_matches("json").trim().to_string();
} }
} }
// Extract function name // Extract function name
if let Ok(parsed) = serde_json::from_str::<Value>(&tool_call_msg) { if let Ok(parsed) = serde_json::from_str::<Value>(&tool_call_msg) {
if let Some(tool_calls_arr) = parsed.get("tool_calls").and_then(|v| v.as_array()) { if let Some(tool_calls_arr) =
parsed.get("tool_calls").and_then(|v| v.as_array())
{
if let Some(first_tool_call) = tool_calls_arr.first() { if let Some(first_tool_call) = tool_calls_arr.first() {
let func_name = first_tool_call let func_name = first_tool_call
.get("name") .get("name")
@ -685,8 +707,10 @@ impl ArchFunctionHandler {
"result": content, "result": content,
}); });
content = format!("<tool_response>\n{}\n</tool_response>", content = format!(
serde_json::to_string(&tool_response)?); "<tool_response>\n{}\n</tool_response>",
serde_json::to_string(&tool_response)?
);
} }
} }
} }
@ -717,7 +741,7 @@ impl ArchFunctionHandler {
if let Some(instruction) = extra_instruction { if let Some(instruction) = extra_instruction {
if let Some(last) = processed_messages.last_mut() { if let Some(last) = processed_messages.last_mut() {
if let MessageContent::Text(content) = &mut last.content { if let MessageContent::Text(content) = &mut last.content {
content.push_str("\n"); content.push('\n');
content.push_str(instruction); content.push_str(instruction);
} }
} }
@ -750,13 +774,11 @@ impl ArchFunctionHandler {
for i in (conversation_idx..messages.len()).rev() { for i in (conversation_idx..messages.len()).rev() {
if let MessageContent::Text(content) = &messages[i].content { if let MessageContent::Text(content) = &messages[i].content {
num_tokens += content.len() / 4; num_tokens += content.len() / 4;
if num_tokens >= max_tokens { if num_tokens >= max_tokens && messages[i].role == Role::User {
if messages[i].role == Role::User { // Set message_idx to current position and break
// Set message_idx to current position and break // This matches Python's behavior where message_idx is set before break
// This matches Python's behavior where message_idx is set before break message_idx = i;
message_idx = i; break;
break;
}
} }
} }
// Only update message_idx if we haven't hit the token limit yet // Only update message_idx if we haven't hit the token limit yet
@ -789,7 +811,11 @@ impl ArchFunctionHandler {
} }
/// Helper to create a request with VLLM-specific parameters /// Helper to create a request with VLLM-specific parameters
fn create_request_with_extra_body(&self, messages: Vec<Message>, stream: bool) -> ChatCompletionsRequest { fn create_request_with_extra_body(
&self,
messages: Vec<Message>,
stream: bool,
) -> ChatCompletionsRequest {
ChatCompletionsRequest { ChatCompletionsRequest {
model: self.model_name.clone(), model: self.model_name.clone(),
messages, messages,
@ -813,24 +839,38 @@ impl ArchFunctionHandler {
} }
/// Makes a streaming request and returns the SSE event stream /// Makes a streaming request and returns the SSE event stream
async fn make_streaming_request(&self, request: ChatCompletionsRequest) -> Result<std::pin::Pin<Box<dyn futures::Stream<Item = std::result::Result<Value, String>> + Send>>> { async fn make_streaming_request(
let request_body = serde_json::to_string(&request) &self,
.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Failed to serialize request: {}", e)))?; request: ChatCompletionsRequest,
) -> Result<
std::pin::Pin<Box<dyn futures::Stream<Item = std::result::Result<Value, String>> + Send>>,
> {
let request_body = serde_json::to_string(&request).map_err(|e| {
FunctionCallingError::InvalidModelResponse(format!(
"Failed to serialize request: {}",
e
))
})?;
let response = self.http_client let response = self
.http_client
.post(&self.endpoint_url) .post(&self.endpoint_url)
.header("Content-Type", "application/json") .header("Content-Type", "application/json")
.body(request_body) .body(request_body)
.send() .send()
.await .await
.map_err(|e| FunctionCallingError::HttpError(e))?; .map_err(FunctionCallingError::HttpError)?;
if !response.status().is_success() { if !response.status().is_success() {
let status = response.status(); let status = response.status();
let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string()); let error_text = response
return Err(FunctionCallingError::InvalidModelResponse( .text()
format!("HTTP error {}: {}", status, error_text) .await
)); .unwrap_or_else(|_| "Unknown error".to_string());
return Err(FunctionCallingError::InvalidModelResponse(format!(
"HTTP error {}: {}",
status, error_text
)));
} }
// Parse SSE stream // Parse SSE stream
@ -856,38 +896,51 @@ impl ArchFunctionHandler {
} }
/// Makes a non-streaming request and returns the response /// Makes a non-streaming request and returns the response
async fn make_non_streaming_request(&self, request: ChatCompletionsRequest) -> Result<ChatCompletionsResponse> { async fn make_non_streaming_request(
let request_body = serde_json::to_string(&request) &self,
.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Failed to serialize request: {}", e)))?; request: ChatCompletionsRequest,
) -> Result<ChatCompletionsResponse> {
let request_body = serde_json::to_string(&request).map_err(|e| {
FunctionCallingError::InvalidModelResponse(format!(
"Failed to serialize request: {}",
e
))
})?;
let response = self.http_client let response = self
.http_client
.post(&self.endpoint_url) .post(&self.endpoint_url)
.header("Content-Type", "application/json") .header("Content-Type", "application/json")
.body(request_body) .body(request_body)
.send() .send()
.await .await
.map_err(|e| FunctionCallingError::HttpError(e))?; .map_err(FunctionCallingError::HttpError)?;
if !response.status().is_success() { if !response.status().is_success() {
let status = response.status(); let status = response.status();
let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string()); let error_text = response
return Err(FunctionCallingError::InvalidModelResponse( .text()
format!("HTTP error {}: {}", status, error_text) .await
)); .unwrap_or_else(|_| "Unknown error".to_string());
return Err(FunctionCallingError::InvalidModelResponse(format!(
"HTTP error {}: {}",
status, error_text
)));
} }
let response_text = response.text().await let response_text = response
.map_err(|e| FunctionCallingError::HttpError(e))?; .text()
.await
.map_err(FunctionCallingError::HttpError)?;
serde_json::from_str(&response_text) serde_json::from_str(&response_text).map_err(FunctionCallingError::JsonParseError)
.map_err(|e| FunctionCallingError::JsonParseError(e))
} }
pub async fn function_calling_chat( pub async fn function_calling_chat(
&self, &self,
request: ChatCompletionsRequest, request: ChatCompletionsRequest,
) -> Result<ChatCompletionsResponse> { ) -> Result<ChatCompletionsResponse> {
use tracing::{info, error}; use tracing::{error, info};
info!("[Arch-Function] - ChatCompletion"); info!("[Arch-Function] - ChatCompletion");
@ -899,10 +952,14 @@ impl ArchFunctionHandler {
request.metadata.as_ref(), request.metadata.as_ref(),
)?; )?;
info!("[request to arch-fc]: model: {}, messages count: {}", info!(
self.model_name, messages.len()); "[request to arch-fc]: model: {}, messages count: {}",
self.model_name,
messages.len()
);
let use_agent_orchestrator = request.metadata let use_agent_orchestrator = request
.metadata
.as_ref() .as_ref()
.and_then(|m| m.get("use_agent_orchestrator")) .and_then(|m| m.get("use_agent_orchestrator"))
.and_then(|v| v.as_bool()) .and_then(|v| v.as_bool())
@ -918,89 +975,95 @@ impl ArchFunctionHandler {
if use_agent_orchestrator { if use_agent_orchestrator {
while let Some(chunk_result) = stream.next().await { while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?; let chunk = chunk_result.map_err(FunctionCallingError::InvalidModelResponse)?;
// Extract content from JSON response // Extract content from JSON response
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) { if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
if let Some(choice) = choices.first() { if let Some(choice) = choices.first() {
if let Some(content) = choice.get("delta") if let Some(content) = choice
.get("delta")
.and_then(|d| d.get("content")) .and_then(|d| d.get("content"))
.and_then(|c| c.as_str()) { .and_then(|c| c.as_str())
{
model_response.push_str(content); model_response.push_str(content);
} }
} }
} }
} }
info!("[Agent Orchestrator]: response received"); info!("[Agent Orchestrator]: response received");
} else { } else if let Some(tools) = request.tools.as_ref() {
if let Some(tools) = request.tools.as_ref() { let mut hallucination_state = HallucinationState::new(tools);
let mut hallucination_state = HallucinationState::new(tools); let mut has_tool_calls = None;
let mut has_tool_calls = None; let mut has_hallucination = false;
let mut has_hallucination = false;
while let Some(chunk_result) = stream.next().await { while let Some(chunk_result) = stream.next().await {
let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?; let chunk = chunk_result.map_err(FunctionCallingError::InvalidModelResponse)?;
// Extract content and logprobs from JSON response // Extract content and logprobs from JSON response
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) { if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
if let Some(choice) = choices.first() { if let Some(choice) = choices.first() {
if let Some(content) = choice.get("delta") if let Some(content) = choice
.and_then(|d| d.get("content")) .get("delta")
.and_then(|c| c.as_str()) { .and_then(|d| d.get("content"))
.and_then(|c| c.as_str())
{
// Extract logprobs
let logprobs: Vec<f64> = choice
.get("logprobs")
.and_then(|lp| lp.get("content"))
.and_then(|c| c.as_array())
.and_then(|arr| arr.first())
.and_then(|token| token.get("top_logprobs"))
.and_then(|tlp| tlp.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.get("logprob").and_then(|lp| lp.as_f64()))
.collect()
})
.unwrap_or_default();
// Extract logprobs if hallucination_state
let logprobs: Vec<f64> = choice.get("logprobs") .append_and_check_token_hallucination(content.to_string(), logprobs)
.and_then(|lp| lp.get("content")) {
.and_then(|c| c.as_array()) has_hallucination = true;
.and_then(|arr| arr.first()) break;
.and_then(|token| token.get("top_logprobs")) }
.and_then(|tlp| tlp.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.get("logprob").and_then(|lp| lp.as_f64()))
.collect()
})
.unwrap_or_default();
if hallucination_state.append_and_check_token_hallucination(content.to_string(), logprobs) { if hallucination_state.tokens.len() > 5 && has_tool_calls.is_none() {
has_hallucination = true; let collected_content = hallucination_state.tokens.join("");
break; has_tool_calls = Some(collected_content.contains("tool_calls"));
}
if hallucination_state.tokens.len() > 5 && has_tool_calls.is_none() {
let collected_content = hallucination_state.tokens.join("");
has_tool_calls = Some(collected_content.contains("tool_calls"));
}
} }
} }
} }
} }
}
if has_tool_calls == Some(true) && has_hallucination { if has_tool_calls == Some(true) && has_hallucination {
info!("[Hallucination]: {}", hallucination_state.error_message); info!("[Hallucination]: {}", hallucination_state.error_message);
let clarify_messages = self.prefill_message(messages.clone(), &self.clarify_prefix); let clarify_messages = self.prefill_message(messages.clone(), &self.clarify_prefix);
let clarify_request = self.create_request_with_extra_body(clarify_messages, false); let clarify_request = self.create_request_with_extra_body(clarify_messages, false);
let retry_response = self.make_non_streaming_request(clarify_request).await?; let retry_response = self.make_non_streaming_request(clarify_request).await?;
if let Some(choice) = retry_response.choices.first() { if let Some(choice) = retry_response.choices.first() {
if let Some(content) = &choice.message.content { if let Some(content) = &choice.message.content {
model_response = content.clone(); model_response = content.clone();
}
} }
} else {
model_response = hallucination_state.tokens.join("");
} }
} else { } else {
while let Some(chunk_result) = stream.next().await { model_response = hallucination_state.tokens.join("");
let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?; }
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) { } else {
if let Some(choice) = choices.first() { while let Some(chunk_result) = stream.next().await {
if let Some(content) = choice.get("delta") let chunk = chunk_result.map_err(FunctionCallingError::InvalidModelResponse)?;
.and_then(|d| d.get("content")) if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
.and_then(|c| c.as_str()) { if let Some(choice) = choices.first() {
model_response.push_str(content); if let Some(content) = choice
} .get("delta")
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str())
{
model_response.push_str(content);
} }
} }
} }
@ -1009,10 +1072,17 @@ impl ArchFunctionHandler {
let response_dict = self.parse_model_response(&model_response); let response_dict = self.parse_model_response(&model_response);
info!("[arch-fc]: raw model response: {}", response_dict.raw_response); info!(
"[arch-fc]: raw model response: {}",
response_dict.raw_response
);
// General model response (no intent matched - should route to default target) // General model response (no intent matched - should route to default target)
let model_message = if response_dict.response.as_ref().map_or(false, |s| !s.is_empty()) { let model_message = if response_dict
.response
.as_ref()
.is_some_and(|s| !s.is_empty())
{
// When arch-fc returns a "response" field, it means no intent was matched // When arch-fc returns a "response" field, it means no intent was matched
// Return empty content and empty tool_calls so prompt_gateway routes to default target // Return empty content and empty tool_calls so prompt_gateway routes to default target
ResponseMessage { ResponseMessage {
@ -1053,8 +1123,11 @@ impl ArchFunctionHandler {
let verification = self.verify_tool_calls(tools, &response_dict.tool_calls); let verification = self.verify_tool_calls(tools, &response_dict.tool_calls);
if verification.is_valid { if verification.is_valid {
info!("[Tool calls]: {:?}", info!(
response_dict.tool_calls.iter() "[Tool calls]: {:?}",
response_dict
.tool_calls
.iter()
.map(|tc| &tc.function) .map(|tc| &tc.function)
.collect::<Vec<_>>() .collect::<Vec<_>>()
); );
@ -1092,8 +1165,11 @@ impl ArchFunctionHandler {
} }
} }
} else { } else {
info!("[Tool calls]: {:?}", info!(
response_dict.tool_calls.iter() "[Tool calls]: {:?}",
response_dict
.tool_calls
.iter()
.map(|tc| &tc.function) .map(|tc| &tc.function)
.collect::<Vec<_>>() .collect::<Vec<_>>()
); );
@ -1108,7 +1184,10 @@ impl ArchFunctionHandler {
} }
} }
} else { } else {
error!("Invalid tool calls in response: {}", response_dict.error_message); error!(
"Invalid tool calls in response: {}",
response_dict.error_message
);
ResponseMessage { ResponseMessage {
role: Role::Assistant, role: Role::Assistant,
content: Some(String::new()), content: Some(String::new()),
@ -1243,7 +1322,6 @@ pub async fn function_calling_chat_handler(
req: Request<Incoming>, req: Request<Incoming>,
llm_provider_url: String, llm_provider_url: String,
) -> std::result::Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> { ) -> std::result::Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
use hermesllm::apis::openai::ChatCompletionsRequest; use hermesllm::apis::openai::ChatCompletionsRequest;
let whole_body = req.collect().await?.to_bytes(); let whole_body = req.collect().await?.to_bytes();
@ -1255,10 +1333,13 @@ pub async fn function_calling_chat_handler(
let mut response = Response::new(full( let mut response = Response::new(full(
serde_json::json!({ serde_json::json!({
"error": format!("Invalid request body: {}", e) "error": format!("Invalid request body: {}", e)
}).to_string() })
.to_string(),
)); ));
*response.status_mut() = StatusCode::BAD_REQUEST; *response.status_mut() = StatusCode::BAD_REQUEST;
response.headers_mut().insert("Content-Type", "application/json".parse().unwrap()); response
.headers_mut()
.insert("Content-Type", "application/json".parse().unwrap());
return Ok(response); return Ok(response);
} }
}; };
@ -1271,24 +1352,31 @@ pub async fn function_calling_chat_handler(
// Parse as ChatCompletionsRequest // Parse as ChatCompletionsRequest
let chat_request: ChatCompletionsRequest = match serde_json::from_value(body_json) { let chat_request: ChatCompletionsRequest = match serde_json::from_value(body_json) {
Ok(req) => { Ok(req) => {
info!("[request body]: {}", serde_json::to_string(&req).unwrap_or_default()); info!(
"[request body]: {}",
serde_json::to_string(&req).unwrap_or_default()
);
req req
}, }
Err(e) => { Err(e) => {
error!("Failed to parse request body: {}", e); error!("Failed to parse request body: {}", e);
let mut response = Response::new(full( let mut response = Response::new(full(
serde_json::json!({ serde_json::json!({
"error": format!("Invalid request body: {}", e) "error": format!("Invalid request body: {}", e)
}).to_string() })
.to_string(),
)); ));
*response.status_mut() = StatusCode::BAD_REQUEST; *response.status_mut() = StatusCode::BAD_REQUEST;
response.headers_mut().insert("Content-Type", "application/json".parse().unwrap()); response
.headers_mut()
.insert("Content-Type", "application/json".parse().unwrap());
return Ok(response); return Ok(response);
} }
}; };
// Determine which handler to use based on metadata // Determine which handler to use based on metadata
let use_agent_orchestrator = chat_request.metadata let use_agent_orchestrator = chat_request
.metadata
.as_ref() .as_ref()
.and_then(|m| m.get("use_agent_orchestrator")) .and_then(|m| m.get("use_agent_orchestrator"))
.and_then(|v| v.as_bool()) .and_then(|v| v.as_bool())
@ -1309,7 +1397,10 @@ pub async fn function_calling_chat_handler(
ARCH_FUNCTION_MODEL_NAME.to_string(), ARCH_FUNCTION_MODEL_NAME.to_string(),
llm_provider_url.clone(), llm_provider_url.clone(),
); );
handler.function_handler.function_calling_chat(chat_request).await handler
.function_handler
.function_calling_chat(chat_request)
.await
} else { } else {
let handler = ArchFunctionHandler::new( let handler = ArchFunctionHandler::new(
ARCH_FUNCTION_MODEL_NAME.to_string(), ARCH_FUNCTION_MODEL_NAME.to_string(),
@ -1328,7 +1419,9 @@ pub async fn function_calling_chat_handler(
let mut response = Response::new(full(response_json)); let mut response = Response::new(full(response_json));
*response.status_mut() = StatusCode::OK; *response.status_mut() = StatusCode::OK;
response.headers_mut().insert("Content-Type", "application/json".parse().unwrap()); response
.headers_mut()
.insert("Content-Type", "application/json".parse().unwrap());
Ok(response) Ok(response)
} }
@ -1341,13 +1434,14 @@ pub async fn function_calling_chat_handler(
let mut response = Response::new(full(error_response.to_string())); let mut response = Response::new(full(error_response.to_string()));
*response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
response.headers_mut().insert("Content-Type", "application/json".parse().unwrap()); response
.headers_mut()
.insert("Content-Type", "application/json".parse().unwrap());
Ok(response) Ok(response)
} }
} }
} }
// ============================================================================ // ============================================================================
// TESTS // TESTS
// ============================================================================ // ============================================================================
@ -1370,10 +1464,13 @@ mod tests {
assert!(config.task_prompt.contains("</tools>\\n\\n")); assert!(config.task_prompt.contains("</tools>\\n\\n"));
// Format prompt should contain literal escaped newlines and proper JSON examples // Format prompt should contain literal escaped newlines and proper JSON examples
assert!(config.format_prompt.contains("\\n\\nBased on your analysis")); assert!(config
assert!(config.format_prompt.contains(r#"{\"response\": \"Your response text here\"}"#)); .format_prompt
.contains("\\n\\nBased on your analysis"));
assert!(config
.format_prompt
.contains(r#"{\"response\": \"Your response text here\"}"#));
assert!(config.format_prompt.contains(r#"{\"tool_calls\": [{"#)); assert!(config.format_prompt.contains(r#"{\"tool_calls\": [{"#));
} }
#[test] #[test]
@ -1384,7 +1481,11 @@ mod tests {
#[test] #[test]
fn test_fix_json_string_valid() { fn test_fix_json_string_valid() {
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string()); let handler = ArchFunctionHandler::new(
"test-model".to_string(),
ArchFunctionConfig::default(),
"http://localhost:8000".to_string(),
);
let json_str = r#"{"name": "test", "value": 123}"#; let json_str = r#"{"name": "test", "value": 123}"#;
let result = handler.fix_json_string(json_str); let result = handler.fix_json_string(json_str);
assert!(result.is_ok()); assert!(result.is_ok());
@ -1392,7 +1493,11 @@ mod tests {
#[test] #[test]
fn test_fix_json_string_missing_bracket() { fn test_fix_json_string_missing_bracket() {
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string()); let handler = ArchFunctionHandler::new(
"test-model".to_string(),
ArchFunctionConfig::default(),
"http://localhost:8000".to_string(),
);
let json_str = r#"{"name": "test", "value": 123"#; let json_str = r#"{"name": "test", "value": 123"#;
let result = handler.fix_json_string(json_str); let result = handler.fix_json_string(json_str);
assert!(result.is_ok()); assert!(result.is_ok());
@ -1402,8 +1507,13 @@ mod tests {
#[test] #[test]
fn test_parse_model_response_with_tool_calls() { fn test_parse_model_response_with_tool_calls() {
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string()); let handler = ArchFunctionHandler::new(
let content = r#"{"tool_calls": [{"name": "get_weather", "arguments": {"location": "NYC"}}]}"#; "test-model".to_string(),
ArchFunctionConfig::default(),
"http://localhost:8000".to_string(),
);
let content =
r#"{"tool_calls": [{"name": "get_weather", "arguments": {"location": "NYC"}}]}"#;
let result = handler.parse_model_response(content); let result = handler.parse_model_response(content);
assert!(result.is_valid); assert!(result.is_valid);
@ -1413,8 +1523,13 @@ mod tests {
#[test] #[test]
fn test_parse_model_response_with_clarification() { fn test_parse_model_response_with_clarification() {
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string()); let handler = ArchFunctionHandler::new(
let content = r#"{"required_functions": ["get_weather"], "clarification": "What location?"}"#; "test-model".to_string(),
ArchFunctionConfig::default(),
"http://localhost:8000".to_string(),
);
let content =
r#"{"required_functions": ["get_weather"], "clarification": "What location?"}"#;
let result = handler.parse_model_response(content); let result = handler.parse_model_response(content);
assert!(result.is_valid); assert!(result.is_valid);
@ -1424,7 +1539,11 @@ mod tests {
#[test] #[test]
fn test_convert_data_type_int_to_float() { fn test_convert_data_type_int_to_float() {
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string()); let handler = ArchFunctionHandler::new(
"test-model".to_string(),
ArchFunctionConfig::default(),
"http://localhost:8000".to_string(),
);
let value = json!(42); let value = json!(42);
let result = handler.convert_data_type(&value, "float"); let result = handler.convert_data_type(&value, "float");
assert!(result.is_ok()); assert!(result.is_ok());
@ -1504,13 +1623,12 @@ pub fn check_threshold(
} }
/// Checks if a parameter is required in the function description /// Checks if a parameter is required in the function description
pub fn is_parameter_required( pub fn is_parameter_required(function_description: &Value, parameter_name: &str) -> bool {
function_description: &Value,
parameter_name: &str,
) -> bool {
if let Some(required) = function_description.get("required") { if let Some(required) = function_description.get("required") {
if let Some(required_arr) = required.as_array() { if let Some(required_arr) = required.as_array() {
return required_arr.iter().any(|v| v.as_str() == Some(parameter_name)); return required_arr
.iter()
.any(|v| v.as_str() == Some(parameter_name));
} }
} }
false false
@ -1559,12 +1677,7 @@ impl HallucinationState {
pub fn new(functions: &[Tool]) -> Self { pub fn new(functions: &[Tool]) -> Self {
let function_properties: HashMap<String, Value> = functions let function_properties: HashMap<String, Value> = functions
.iter() .iter()
.map(|tool| { .map(|tool| (tool.function.name.clone(), tool.function.parameters.clone()))
(
tool.function.name.clone(),
tool.function.parameters.clone(),
)
})
.collect(); .collect();
Self { Self {
@ -1620,7 +1733,10 @@ impl HallucinationState {
// Function name extraction logic // Function name extraction logic
if self.state.as_deref() == Some("function_name") { if self.state.as_deref() == Some("function_name") {
if !FUNC_NAME_END_TOKEN.iter().any(|&t| self.tokens.last().map_or(false, |tok| tok == t)) { if !FUNC_NAME_END_TOKEN
.iter()
.any(|&t| self.tokens.last().is_some_and(|tok| tok == t))
{
self.mask.push(MaskToken::FunctionName); self.mask.push(MaskToken::FunctionName);
} else { } else {
self.state = None; self.state = None;
@ -1629,34 +1745,51 @@ impl HallucinationState {
} }
// Check for function name start // Check for function name start
if FUNC_NAME_START_PATTERN.iter().any(|&p| content.ends_with(p)) { if FUNC_NAME_START_PATTERN
.iter()
.any(|&p| content.ends_with(p))
{
self.state = Some("function_name".to_string()); self.state = Some("function_name".to_string());
} }
// Parameter name extraction logic // Parameter name extraction logic
if self.state.as_deref() == Some("parameter_name") if self.state.as_deref() == Some("parameter_name")
&& !PARAMETER_NAME_END_TOKENS.iter().any(|&t| content.ends_with(t)) { && !PARAMETER_NAME_END_TOKENS
.iter()
.any(|&t| content.ends_with(t))
{
self.mask.push(MaskToken::ParameterName); self.mask.push(MaskToken::ParameterName);
} else if self.state.as_deref() == Some("parameter_name") } else if self.state.as_deref() == Some("parameter_name")
&& PARAMETER_NAME_END_TOKENS.iter().any(|&t| content.ends_with(t)) { && PARAMETER_NAME_END_TOKENS
.iter()
.any(|&t| content.ends_with(t))
{
self.state = None; self.state = None;
self.parameter_name_done = true; self.parameter_name_done = true;
self.get_parameter_name(); self.get_parameter_name();
} else if self.parameter_name_done } else if self.parameter_name_done
&& !self.open_bracket && !self.open_bracket
&& PARAMETER_NAME_START_PATTERN.iter().any(|&p| content.ends_with(p)) { && PARAMETER_NAME_START_PATTERN
.iter()
.any(|&p| content.ends_with(p))
{
self.state = Some("parameter_name".to_string()); self.state = Some("parameter_name".to_string());
} }
// First parameter value start // First parameter value start
if FIRST_PARAM_NAME_START_PATTERN.iter().any(|&p| content.ends_with(p)) { if FIRST_PARAM_NAME_START_PATTERN
.iter()
.any(|&p| content.ends_with(p))
{
self.state = Some("parameter_name".to_string()); self.state = Some("parameter_name".to_string());
} }
// Parameter value extraction logic // Parameter value extraction logic
if self.state.as_deref() == Some("parameter_value") if self.state.as_deref() == Some("parameter_value")
&& !PARAMETER_VALUE_END_TOKEN.iter().any(|&t| content.ends_with(t)) { && !PARAMETER_VALUE_END_TOKEN
.iter()
.any(|&t| content.ends_with(t))
{
// Check for brackets // Check for brackets
if let Some(last_token) = self.tokens.last() { if let Some(last_token) = self.tokens.last() {
let open_brackets: Vec<char> = last_token let open_brackets: Vec<char> = last_token
@ -1694,8 +1827,11 @@ impl HallucinationState {
&& self.mask[self.mask.len() - 2] != MaskToken::ParameterValue && self.mask[self.mask.len() - 2] != MaskToken::ParameterValue
&& !self.parameter_name.is_empty() && !self.parameter_name.is_empty()
{ {
let last_param = self.parameter_name[self.parameter_name.len() - 1].clone(); let last_param =
if let Some(func_props) = self.function_properties.get(&self.function_name) { self.parameter_name[self.parameter_name.len() - 1].clone();
if let Some(func_props) =
self.function_properties.get(&self.function_name)
{
if is_parameter_required(func_props, &last_param) if is_parameter_required(func_props, &last_param)
&& !is_parameter_property(func_props, &last_param, "enum") && !is_parameter_property(func_props, &last_param, "enum")
&& !self.check_parameter_name.contains_key(&last_param) && !self.check_parameter_name.contains_key(&last_param)
@ -1718,10 +1854,16 @@ impl HallucinationState {
} }
} else if self.state.as_deref() == Some("parameter_value") } else if self.state.as_deref() == Some("parameter_value")
&& !self.open_bracket && !self.open_bracket
&& PARAMETER_VALUE_END_TOKEN.iter().any(|&t| content.ends_with(t)) { && PARAMETER_VALUE_END_TOKEN
.iter()
.any(|&t| content.ends_with(t))
{
self.state = None; self.state = None;
} else if self.parameter_name_done } else if self.parameter_name_done
&& PARAMETER_VALUE_START_PATTERN.iter().any(|&p| content.ends_with(p)) { && PARAMETER_VALUE_START_PATTERN
.iter()
.any(|&p| content.ends_with(p))
{
self.state = Some("parameter_value".to_string()); self.state = Some("parameter_value".to_string());
} }
@ -1848,18 +1990,18 @@ mod hallucination_tests {
let handler = ArchFunctionHandler::new( let handler = ArchFunctionHandler::new(
"test-model".to_string(), "test-model".to_string(),
ArchFunctionConfig::default(), ArchFunctionConfig::default(),
"http://localhost:8000".to_string() "http://localhost:8000".to_string(),
); );
// Test integer types // Test integer types
assert!(handler.check_value_type(&json!(42), "integer")); assert!(handler.check_value_type(&json!(42), "integer"));
assert!(handler.check_value_type(&json!(42), "int")); assert!(handler.check_value_type(&json!(42), "int"));
assert!(!handler.check_value_type(&json!(3.14), "integer")); assert!(!handler.check_value_type(&json!(3.15), "integer"));
// Test number types (accepts both int and float) // Test number types (accepts both int and float)
assert!(handler.check_value_type(&json!(3.14), "number")); assert!(handler.check_value_type(&json!(3.15), "number"));
assert!(handler.check_value_type(&json!(42), "number")); assert!(handler.check_value_type(&json!(42), "number"));
assert!(handler.check_value_type(&json!(3.14), "float")); assert!(handler.check_value_type(&json!(3.15), "float"));
// Test boolean // Test boolean
assert!(handler.check_value_type(&json!(true), "boolean")); assert!(handler.check_value_type(&json!(true), "boolean"));
@ -1890,12 +2032,16 @@ mod hallucination_tests {
let handler = ArchFunctionHandler::new( let handler = ArchFunctionHandler::new(
"test-model".to_string(), "test-model".to_string(),
ArchFunctionConfig::default(), ArchFunctionConfig::default(),
"http://localhost:8000".to_string() "http://localhost:8000".to_string(),
); );
// Test valid type - no conversion needed // Test valid type - no conversion needed
assert!(handler.validate_or_convert_parameter(&json!(42), "integer").unwrap()); assert!(handler
assert!(handler.validate_or_convert_parameter(&json!("hello"), "string").unwrap()); .validate_or_convert_parameter(&json!(42), "integer")
.unwrap());
assert!(handler
.validate_or_convert_parameter(&json!("hello"), "string")
.unwrap());
// Test integer to float conversion (convert_data_type supports this) // Test integer to float conversion (convert_data_type supports this)
let result = handler.validate_or_convert_parameter(&json!(42), "float"); let result = handler.validate_or_convert_parameter(&json!(42), "float");
@ -1910,8 +2056,12 @@ mod hallucination_tests {
assert!(!result.unwrap()); assert!(!result.unwrap());
// Test number accepting both int and float // Test number accepting both int and float
assert!(handler.validate_or_convert_parameter(&json!(42), "number").unwrap()); assert!(handler
assert!(handler.validate_or_convert_parameter(&json!(3.14), "number").unwrap()); .validate_or_convert_parameter(&json!(42), "number")
.unwrap());
assert!(handler
.validate_or_convert_parameter(&json!(3.15), "number")
.unwrap());
} }
#[test] #[test]

View file

@ -14,7 +14,7 @@ use crate::router::plano_orchestrator::OrchestratorService;
/// 2. PipelineProcessor - executes the agent pipeline /// 2. PipelineProcessor - executes the agent pipeline
/// 3. ResponseHandler - handles response streaming /// 3. ResponseHandler - handles response streaming
#[cfg(test)] #[cfg(test)]
mod integration_tests { mod tests {
use super::*; use super::*;
use common::configuration::{Agent, AgentFilterChain, Listener}; use common::configuration::{Agent, AgentFilterChain, Listener};
@ -62,7 +62,10 @@ mod integration_tests {
let agent_pipeline = AgentFilterChain { let agent_pipeline = AgentFilterChain {
id: "terminal-agent".to_string(), id: "terminal-agent".to_string(),
filter_chain: Some(vec!["filter-agent".to_string(), "terminal-agent".to_string()]), filter_chain: Some(vec![
"filter-agent".to_string(),
"terminal-agent".to_string(),
]),
description: Some("Test pipeline".to_string()), description: Some("Test pipeline".to_string()),
default: Some(true), default: Some(true),
}; };

View file

@ -2,48 +2,48 @@ use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
pub const JSON_RPC_VERSION: &str = "2.0"; pub const JSON_RPC_VERSION: &str = "2.0";
pub const TOOL_CALL_METHOD : &str = "tools/call"; pub const TOOL_CALL_METHOD: &str = "tools/call";
pub const MCP_INITIALIZE: &str = "initialize"; pub const MCP_INITIALIZE: &str = "initialize";
pub const MCP_INITIALIZE_NOTIFICATION: &str = "notifications/initialized"; pub const MCP_INITIALIZE_NOTIFICATION: &str = "notifications/initialized";
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)] #[serde(untagged)]
pub enum JsonRpcId { pub enum JsonRpcId {
String(String), String(String),
Number(u64), Number(u64),
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcRequest { pub struct JsonRpcRequest {
pub jsonrpc: String, pub jsonrpc: String,
pub id: JsonRpcId, pub id: JsonRpcId,
pub method: String, pub method: String,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<HashMap<String, serde_json::Value>>, pub params: Option<HashMap<String, serde_json::Value>>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcNotification { pub struct JsonRpcNotification {
pub jsonrpc: String, pub jsonrpc: String,
pub method: String, pub method: String,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<HashMap<String, serde_json::Value>>, pub params: Option<HashMap<String, serde_json::Value>>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcError { pub struct JsonRpcError {
pub code: i32, pub code: i32,
pub message: String, pub message: String,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub data: Option<serde_json::Value>, pub data: Option<serde_json::Value>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcResponse { pub struct JsonRpcResponse {
pub jsonrpc: String, pub jsonrpc: String,
pub id: JsonRpcId, pub id: JsonRpcId,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub result: Option<HashMap<String, serde_json::Value>>, pub result: Option<HashMap<String, serde_json::Value>>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<JsonRpcError>, pub error: Option<JsonRpcError>,
} }

View file

@ -1,6 +1,8 @@
use bytes::Bytes; use bytes::Bytes;
use common::configuration::{LlmProvider, ModelAlias}; use common::configuration::{LlmProvider, ModelAlias};
use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER}; use common::consts::{
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
};
use common::traces::TraceCollector; use common::traces::TraceCollector;
use hermesllm::apis::openai_responses::InputParam; use hermesllm::apis::openai_responses::InputParam;
use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs}; use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
@ -14,13 +16,14 @@ use std::sync::Arc;
use tokio::sync::RwLock; use tokio::sync::RwLock;
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
use crate::router::llm_router::RouterService;
use crate::handlers::utils::{create_streaming_response, ObservableStreamProcessor, truncate_message};
use crate::handlers::router_chat::router_chat_get_upstream_model; use crate::handlers::router_chat::router_chat_get_upstream_model;
use crate::handlers::utils::{
create_streaming_response, truncate_message, ObservableStreamProcessor,
};
use crate::router::llm_router::RouterService;
use crate::state::response_state_processor::ResponsesStateProcessor; use crate::state::response_state_processor::ResponsesStateProcessor;
use crate::state::{ use crate::state::{
StateStorage, StateStorageError, extract_input_items, retrieve_and_combine_input, StateStorage, StateStorageError,
extract_input_items, retrieve_and_combine_input
}; };
use crate::tracing::operation_component; use crate::tracing::operation_component;
@ -39,7 +42,6 @@ pub async fn llm_chat(
trace_collector: Arc<TraceCollector>, trace_collector: Arc<TraceCollector>,
state_storage: Option<Arc<dyn StateStorage>>, state_storage: Option<Arc<dyn StateStorage>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> { ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let request_path = request.uri().path().to_string(); let request_path = request.uri().path().to_string();
let request_headers = request.headers().clone(); let request_headers = request.headers().clone();
let request_id = request_headers let request_id = request_headers
@ -74,8 +76,14 @@ pub async fn llm_chat(
)) { )) {
Ok(request) => request, Ok(request) => request,
Err(err) => { Err(err) => {
warn!("[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request as ProviderRequestType: {}", request_id, err); warn!(
let err_msg = format!("[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request: {}", request_id, err); "[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request as ProviderRequestType: {}",
request_id, err
);
let err_msg = format!(
"[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request: {}",
request_id, err
);
let mut bad_request = Response::new(full(err_msg)); let mut bad_request = Response::new(full(err_msg));
*bad_request.status_mut() = StatusCode::BAD_REQUEST; *bad_request.status_mut() = StatusCode::BAD_REQUEST;
return Ok(bad_request); return Ok(bad_request);
@ -85,7 +93,10 @@ pub async fn llm_chat(
// === v1/responses state management: Extract input items early === // === v1/responses state management: Extract input items early ===
let mut original_input_items = Vec::new(); let mut original_input_items = Vec::new();
let client_api = SupportedAPIsFromClient::from_endpoint(request_path.as_str()); let client_api = SupportedAPIsFromClient::from_endpoint(request_path.as_str());
let is_responses_api_client = matches!(client_api, Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_))); let is_responses_api_client = matches!(
client_api,
Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_))
);
// Model alias resolution: update model field in client_request immediately // Model alias resolution: update model field in client_request immediately
// This ensures all downstream objects use the resolved model // This ensures all downstream objects use the resolved model
@ -96,20 +107,28 @@ pub async fn llm_chat(
// Extract tool names and user message preview for span attributes // Extract tool names and user message preview for span attributes
let tool_names = client_request.get_tool_names(); let tool_names = client_request.get_tool_names();
let user_message_preview = client_request.get_recent_user_message() let user_message_preview = client_request
.get_recent_user_message()
.map(|msg| truncate_message(&msg, 50)); .map(|msg| truncate_message(&msg, 50));
client_request.set_model(resolved_model.clone()); client_request.set_model(resolved_model.clone());
if client_request.remove_metadata_key("archgw_preference_config") { if client_request.remove_metadata_key("archgw_preference_config") {
debug!("[PLANO_REQ_ID:{}] Removed archgw_preference_config from metadata", request_id); debug!(
"[PLANO_REQ_ID:{}] Removed archgw_preference_config from metadata",
request_id
);
} }
// === v1/responses state management: Determine upstream API and combine input if needed === // === v1/responses state management: Determine upstream API and combine input if needed ===
// Do this BEFORE routing since routing consumes the request // Do this BEFORE routing since routing consumes the request
// Only process state if state_storage is configured // Only process state if state_storage is configured
let mut should_manage_state = false; let mut should_manage_state = false;
if is_responses_api_client && state_storage.is_some() { if is_responses_api_client {
if let ProviderRequestType::ResponsesAPIRequest(ref mut responses_req) = client_request { if let (
ProviderRequestType::ResponsesAPIRequest(ref mut responses_req),
Some(ref state_store),
) = (&mut client_request, &state_storage)
{
// Extract original input once // Extract original input once
original_input_items = extract_input_items(&responses_req.input); original_input_items = extract_input_items(&responses_req.input);
@ -120,18 +139,22 @@ pub async fn llm_chat(
&request_path, &request_path,
&resolved_model, &resolved_model,
is_streaming_request, is_streaming_request,
).await; )
.await;
let upstream_api = SupportedUpstreamAPIs::from_endpoint(&upstream_path); let upstream_api = SupportedUpstreamAPIs::from_endpoint(&upstream_path);
// Only manage state if upstream is NOT OpenAIResponsesAPI (needs translation) // Only manage state if upstream is NOT OpenAIResponsesAPI (needs translation)
should_manage_state = !matches!(upstream_api, Some(SupportedUpstreamAPIs::OpenAIResponsesAPI(_))); should_manage_state = !matches!(
upstream_api,
Some(SupportedUpstreamAPIs::OpenAIResponsesAPI(_))
);
if should_manage_state { if should_manage_state {
// Retrieve and combine conversation history if previous_response_id exists // Retrieve and combine conversation history if previous_response_id exists
if let Some(ref prev_resp_id) = responses_req.previous_response_id { if let Some(ref prev_resp_id) = responses_req.previous_response_id {
match retrieve_and_combine_input( match retrieve_and_combine_input(
state_storage.as_ref().unwrap().clone(), state_store.clone(),
prev_resp_id, prev_resp_id,
original_input_items, // Pass ownership instead of cloning original_input_items, // Pass ownership instead of cloning
) )
@ -166,7 +189,10 @@ pub async fn llm_chat(
} }
} }
} else { } else {
debug!("[PLANO_REQ_ID:{}] | BRIGHT_STAFF | Upstream supports ResponsesAPI natively.", request_id); debug!(
"[PLANO_REQ_ID:{}] | BRIGHT_STAFF | Upstream supports ResponsesAPI natively.",
request_id
);
} }
} }
} }
@ -177,7 +203,7 @@ pub async fn llm_chat(
// Determine routing using the dedicated router_chat module // Determine routing using the dedicated router_chat module
let routing_result = match router_chat_get_upstream_model( let routing_result = match router_chat_get_upstream_model(
router_service, router_service,
client_request, // Pass the original request - router_chat will convert it client_request, // Pass the original request - router_chat will convert it
&request_headers, &request_headers,
trace_collector.clone(), trace_collector.clone(),
&traceparent, &traceparent,
@ -257,7 +283,8 @@ pub async fn llm_chat(
user_message_preview, user_message_preview,
temperature, temperature,
&llm_providers, &llm_providers,
).await; )
.await;
// Create base processor for metrics and tracing // Create base processor for metrics and tracing
let base_processor = ObservableStreamProcessor::new( let base_processor = ObservableStreamProcessor::new(
@ -269,7 +296,11 @@ pub async fn llm_chat(
// === v1/responses state management: Wrap with ResponsesStateProcessor === // === v1/responses state management: Wrap with ResponsesStateProcessor ===
// Only wrap if we need to manage state (client is ResponsesAPI AND upstream is NOT ResponsesAPI AND state_storage is configured) // Only wrap if we need to manage state (client is ResponsesAPI AND upstream is NOT ResponsesAPI AND state_storage is configured)
let streaming_response = if should_manage_state && !original_input_items.is_empty() && state_storage.is_some() { let streaming_response = if let (true, false, Some(state_store)) = (
should_manage_state,
original_input_items.is_empty(),
state_storage,
) {
// Extract Content-Encoding header to handle decompression for state parsing // Extract Content-Encoding header to handle decompression for state parsing
let content_encoding = response_headers let content_encoding = response_headers
.get("content-encoding") .get("content-encoding")
@ -279,7 +310,7 @@ pub async fn llm_chat(
// Wrap with state management processor to store state after response completes // Wrap with state management processor to store state after response completes
let state_processor = ResponsesStateProcessor::new( let state_processor = ResponsesStateProcessor::new(
base_processor, base_processor,
state_storage.unwrap(), state_store,
original_input_items, original_input_items,
resolved_model.clone(), resolved_model.clone(),
model_name.clone(), model_name.clone(),
@ -324,6 +355,7 @@ fn resolve_model_alias(
} }
/// Builds the LLM span with all required and optional attributes. /// Builds the LLM span with all required and optional attributes.
#[allow(clippy::too_many_arguments)]
async fn build_llm_span( async fn build_llm_span(
traceparent: &str, traceparent: &str,
request_path: &str, request_path: &str,
@ -337,8 +369,8 @@ async fn build_llm_span(
temperature: Option<f32>, temperature: Option<f32>,
llm_providers: &Arc<RwLock<Vec<LlmProvider>>>, llm_providers: &Arc<RwLock<Vec<LlmProvider>>>,
) -> common::traces::Span { ) -> common::traces::Span {
use common::traces::{SpanBuilder, SpanKind, parse_traceparent};
use crate::tracing::{http, llm, OperationNameBuilder}; use crate::tracing::{http, llm, OperationNameBuilder};
use common::traces::{parse_traceparent, SpanBuilder, SpanKind};
// Calculate the upstream path based on provider configuration // Calculate the upstream path based on provider configuration
let upstream_path = get_upstream_path( let upstream_path = get_upstream_path(
@ -347,13 +379,14 @@ async fn build_llm_span(
request_path, request_path,
resolved_model, resolved_model,
is_streaming, is_streaming,
).await; )
.await;
// Build operation name showing path transformation if different // Build operation name showing path transformation if different
let operation_name = if request_path != upstream_path { let operation_name = if request_path != upstream_path {
OperationNameBuilder::new() OperationNameBuilder::new()
.with_method("POST") .with_method("POST")
.with_path(&format!("{} >> {}", request_path, upstream_path)) .with_path(format!("{} >> {}", request_path, upstream_path))
.with_target(resolved_model) .with_target(resolved_model)
.build() .build()
} else { } else {
@ -388,7 +421,8 @@ async fn build_llm_span(
} }
if let Some(tools) = tool_names { if let Some(tools) = tool_names {
let formatted_tools = tools.iter() let formatted_tools = tools
.iter()
.map(|name| format!("{}(...)", name)) .map(|name| format!("{}(...)", name))
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join("\n"); .join("\n");
@ -436,8 +470,7 @@ async fn get_provider_info(
// First, try to find by model name or provider name // First, try to find by model name or provider name
let provider = providers_lock.iter().find(|p| { let provider = providers_lock.iter().find(|p| {
p.model.as_ref().map(|m| m == model_name).unwrap_or(false) p.model.as_ref().map(|m| m == model_name).unwrap_or(false) || p.name == model_name
|| p.name == model_name
}); });
if let Some(provider) = provider { if let Some(provider) = provider {
@ -446,9 +479,7 @@ async fn get_provider_info(
return (provider_id, prefix); return (provider_id, prefix);
} }
let default_provider = providers_lock.iter().find(|p| { let default_provider = providers_lock.iter().find(|p| p.default.unwrap_or(false));
p.default.unwrap_or(false)
});
if let Some(provider) = default_provider { if let Some(provider) = default_provider {
let provider_id = provider.provider_interface.to_provider_id(); let provider_id = provider.provider_interface.to_provider_id();

View file

@ -1,13 +1,13 @@
pub mod agent_chat_completions; pub mod agent_chat_completions;
pub mod agent_selector; pub mod agent_selector;
pub mod llm;
pub mod router_chat;
pub mod models;
pub mod function_calling; pub mod function_calling;
pub mod jsonrpc;
pub mod llm;
pub mod models;
pub mod pipeline_processor; pub mod pipeline_processor;
pub mod response_handler; pub mod response_handler;
pub mod router_chat;
pub mod utils; pub mod utils;
pub mod jsonrpc;
#[cfg(test)] #[cfg(test)]
mod integration_tests; mod integration_tests;

View file

@ -82,6 +82,7 @@ impl PipelineProcessor {
} }
/// Record a span for filter execution /// Record a span for filter execution
#[allow(clippy::too_many_arguments)]
fn record_filter_span( fn record_filter_span(
&self, &self,
collector: &std::sync::Arc<common::traces::TraceCollector>, collector: &std::sync::Arc<common::traces::TraceCollector>,
@ -132,6 +133,7 @@ impl PipelineProcessor {
} }
/// Record a span for MCP protocol interactions /// Record a span for MCP protocol interactions
#[allow(clippy::too_many_arguments)]
fn record_agent_filter_span( fn record_agent_filter_span(
&self, &self,
collector: &std::sync::Arc<common::traces::TraceCollector>, collector: &std::sync::Arc<common::traces::TraceCollector>,
@ -156,12 +158,12 @@ impl PipelineProcessor {
.build(); .build();
let mut span_builder = SpanBuilder::new(&operation_name) let mut span_builder = SpanBuilder::new(&operation_name)
.with_span_id(span_id.unwrap_or_else(|| generate_random_span_id())) .with_span_id(span_id.unwrap_or_else(generate_random_span_id))
.with_kind(SpanKind::Client) .with_kind(SpanKind::Client)
.with_start_time(start_time) .with_start_time(start_time)
.with_end_time(end_time) .with_end_time(end_time)
.with_attribute(http::METHOD, "POST") .with_attribute(http::METHOD, "POST")
.with_attribute(http::TARGET, &format!("/mcp ({})", operation.to_string())) .with_attribute(http::TARGET, format!("/mcp ({})", operation))
.with_attribute("mcp.operation", operation.to_string()) .with_attribute("mcp.operation", operation.to_string())
.with_attribute("mcp.agent_id", agent_id.to_string()) .with_attribute("mcp.agent_id", agent_id.to_string())
.with_attribute( .with_attribute(
@ -188,6 +190,7 @@ impl PipelineProcessor {
} }
/// Process the filter chain of agents (all except the terminal agent) /// Process the filter chain of agents (all except the terminal agent)
#[allow(clippy::too_many_arguments)]
pub async fn process_filter_chain( pub async fn process_filter_chain(
&mut self, &mut self,
chat_history: &[Message], chat_history: &[Message],
@ -1023,7 +1026,7 @@ mod tests {
} }
}); });
let sse_body = format!("event: message\ndata: {}\n\n", rpc_body.to_string()); let sse_body = format!("event: message\ndata: {}\n\n", rpc_body);
let mut server = Server::new_async().await; let mut server = Server::new_async().await;
let _m = server let _m = server
@ -1061,10 +1064,10 @@ mod tests {
.await; .await;
match result { match result {
Err(PipelineError::ClientError { status, body, .. }) => { Err(PipelineError::ClientError { status, body, .. }) => {
assert_eq!(status, 400); assert_eq!(status, 400);
assert_eq!(body, "bad tool call"); assert_eq!(body, "bad tool call");
} }
_ => panic!("Expected client error when isError flag is set"), _ => panic!("Expected client error when isError flag is set"),
} }
} }

View file

@ -133,9 +133,7 @@ impl ResponseHandler {
let response_headers = llm_response.headers(); let response_headers = llm_response.headers();
let is_sse_streaming = response_headers let is_sse_streaming = response_headers
.get(hyper::header::CONTENT_TYPE) .get(hyper::header::CONTENT_TYPE)
.map_or(false, |v| { .is_some_and(|v| v.to_str().unwrap_or("").contains("text/event-stream"));
v.to_str().unwrap_or("").contains("text/event-stream")
});
let response_bytes = llm_response let response_bytes = llm_response
.bytes() .bytes()
@ -164,7 +162,7 @@ impl ResponseHandler {
match transformed_event.provider_response() { match transformed_event.provider_response() {
Ok(provider_response) => { Ok(provider_response) => {
if let Some(content) = provider_response.content_delta() { if let Some(content) = provider_response.content_delta() {
accumulated_text.push_str(&content); accumulated_text.push_str(content);
} else { } else {
info!("No content delta in provider response"); info!("No content delta in provider response");
} }
@ -174,7 +172,7 @@ impl ResponseHandler {
} }
} }
} }
return Ok(accumulated_text); Ok(accumulated_text)
} else { } else {
// If not SSE, treat as regular text response // If not SSE, treat as regular text response
let response_text = String::from_utf8(response_bytes.to_vec()).map_err(|e| { let response_text = String::from_utf8(response_bytes.to_vec()).map_err(|e| {

View file

@ -1,6 +1,6 @@
use common::configuration::ModelUsagePreference; use common::configuration::ModelUsagePreference;
use common::consts::{REQUEST_ID_HEADER}; use common::consts::REQUEST_ID_HEADER;
use common::traces::{TraceCollector, SpanKind, SpanBuilder, parse_traceparent}; use common::traces::{parse_traceparent, SpanBuilder, SpanKind, TraceCollector};
use hermesllm::clients::endpoints::SupportedUpstreamAPIs; use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
use hermesllm::{ProviderRequest, ProviderRequestType}; use hermesllm::{ProviderRequest, ProviderRequestType};
use hyper::StatusCode; use hyper::StatusCode;
@ -9,10 +9,10 @@ use std::sync::Arc;
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
use crate::router::llm_router::RouterService; use crate::router::llm_router::RouterService;
use crate::tracing::{OperationNameBuilder, operation_component, http, routing}; use crate::tracing::{http, operation_component, routing, OperationNameBuilder};
pub struct RoutingResult { pub struct RoutingResult {
pub model_name: String pub model_name: String,
} }
pub struct RoutingError { pub struct RoutingError {
@ -24,7 +24,7 @@ impl RoutingError {
pub fn internal_error(message: String) -> Self { pub fn internal_error(message: String) -> Self {
Self { Self {
message, message,
status_code: StatusCode::INTERNAL_SERVER_ERROR status_code: StatusCode::INTERNAL_SERVER_ERROR,
} }
} }
} }
@ -52,9 +52,7 @@ pub async fn router_chat_get_upstream_model(
// Convert to ChatCompletionsRequest for routing (regardless of input type) // Convert to ChatCompletionsRequest for routing (regardless of input type)
let chat_request = match ProviderRequestType::try_from(( let chat_request = match ProviderRequestType::try_from((
client_request, client_request,
&SupportedUpstreamAPIs::OpenAIChatCompletions( &SupportedUpstreamAPIs::OpenAIChatCompletions(hermesllm::apis::OpenAIApi::ChatCompletions),
hermesllm::apis::OpenAIApi::ChatCompletions,
),
)) { )) {
Ok(ProviderRequestType::ChatCompletionsRequest(req)) => req, Ok(ProviderRequestType::ChatCompletionsRequest(req)) => req,
Ok( Ok(
@ -69,7 +67,10 @@ pub async fn router_chat_get_upstream_model(
)); ));
} }
Err(err) => { Err(err) => {
warn!("Failed to convert request to ChatCompletionsRequest: {}", err); warn!(
"Failed to convert request to ChatCompletionsRequest: {}",
err
);
return Err(RoutingError::internal_error(format!( return Err(RoutingError::internal_error(format!(
"Failed to convert request: {}", "Failed to convert request: {}",
err err
@ -151,9 +152,7 @@ pub async fn router_chat_get_upstream_model(
) )
.await; .await;
Ok(RoutingResult { Ok(RoutingResult { model_name })
model_name
})
} }
None => { None => {
// No route determined, use default model from request // No route determined, use default model from request
@ -176,7 +175,7 @@ pub async fn router_chat_get_upstream_model(
.await; .await;
Ok(RoutingResult { Ok(RoutingResult {
model_name: default_model model_name: default_model,
}) })
} }
}, },
@ -194,9 +193,10 @@ pub async fn router_chat_get_upstream_model(
) )
.await; .await;
Err(RoutingError::internal_error( Err(RoutingError::internal_error(format!(
format!("Failed to determine route: {}", err) "Failed to determine route: {}",
)) err
)))
} }
} }
} }
@ -230,7 +230,10 @@ async fn record_routing_span(
.with_end_time(std::time::SystemTime::now()) .with_end_time(std::time::SystemTime::now())
.with_attribute(http::METHOD, "POST") .with_attribute(http::METHOD, "POST")
.with_attribute(http::TARGET, routing_api_path.to_string()) .with_attribute(http::TARGET, routing_api_path.to_string())
.with_attribute(routing::ROUTE_DETERMINATION_MS, start_time.elapsed().as_millis().to_string()); .with_attribute(
routing::ROUTE_DETERMINATION_MS,
start_time.elapsed().as_millis().to_string(),
);
// Only set parent span ID if it exists (not a root span) // Only set parent span ID if it exists (not a root span)
if let Some(parent) = parent_span_id { if let Some(parent) = parent_span_id {

View file

@ -1,5 +1,5 @@
use bytes::Bytes; use bytes::Bytes;
use common::traces::{Span, Attribute, AttributeValue, TraceCollector, Event}; use common::traces::{Attribute, AttributeValue, Event, Span, TraceCollector};
use http_body_util::combinators::BoxBody; use http_body_util::combinators::BoxBody;
use http_body_util::StreamBody; use http_body_util::StreamBody;
use hyper::body::Frame; use hyper::body::Frame;
@ -11,7 +11,7 @@ use tokio_stream::StreamExt;
use tracing::warn; use tracing::warn;
// Import tracing constants // Import tracing constants
use crate::tracing::{llm, error}; use crate::tracing::{error, llm};
/// Trait for processing streaming chunks /// Trait for processing streaming chunks
/// Implementors can inject custom logic during streaming (e.g., hallucination detection, logging) /// Implementors can inject custom logic during streaming (e.g., hallucination detection, logging)
@ -97,7 +97,6 @@ impl StreamProcessor for ObservableStreamProcessor {
}, },
}); });
self.span.attributes.push(Attribute { self.span.attributes.push(Attribute {
key: llm::DURATION_MS.to_string(), key: llm::DURATION_MS.to_string(),
value: AttributeValue { value: AttributeValue {
@ -119,11 +118,9 @@ impl StreamProcessor for ObservableStreamProcessor {
if let Ok(start_time_nanos) = self.span.start_time_unix_nano.parse::<u128>() { if let Ok(start_time_nanos) = self.span.start_time_unix_nano.parse::<u128>() {
// Convert ttft from milliseconds to nanoseconds and add to start time // Convert ttft from milliseconds to nanoseconds and add to start time
let event_timestamp = start_time_nanos + (ttft * 1_000_000); let event_timestamp = start_time_nanos + (ttft * 1_000_000);
let mut event = Event::new(llm::TIME_TO_FIRST_TOKEN_MS.to_string(), event_timestamp); let mut event =
event.add_attribute( Event::new(llm::TIME_TO_FIRST_TOKEN_MS.to_string(), event_timestamp);
llm::TIME_TO_FIRST_TOKEN_MS.to_string(), event.add_attribute(llm::TIME_TO_FIRST_TOKEN_MS.to_string(), ttft.to_string());
ttft.to_string(),
);
// Initialize events vector if needed // Initialize events vector if needed
if self.span.events.is_none() { if self.span.events.is_none() {
@ -137,7 +134,8 @@ impl StreamProcessor for ObservableStreamProcessor {
} }
// Record the finalized span // Record the finalized span
self.collector.record_span(&self.service_name, self.span.clone()); self.collector
.record_span(&self.service_name, self.span.clone());
} }
fn on_error(&mut self, error_msg: &str) { fn on_error(&mut self, error_msg: &str) {
@ -173,7 +171,8 @@ impl StreamProcessor for ObservableStreamProcessor {
}); });
// Record the error span // Record the error span
self.collector.record_span(&self.service_name, self.span.clone()); self.collector
.record_span(&self.service_name, self.span.clone());
} }
} }

View file

@ -4,13 +4,15 @@ use brightstaff::handlers::llm::llm_chat;
use brightstaff::handlers::models::list_models; use brightstaff::handlers::models::list_models;
use brightstaff::router::llm_router::RouterService; use brightstaff::router::llm_router::RouterService;
use brightstaff::router::plano_orchestrator::OrchestratorService; use brightstaff::router::plano_orchestrator::OrchestratorService;
use brightstaff::state::StateStorage;
use brightstaff::state::postgresql::PostgreSQLConversationStorage;
use brightstaff::state::memory::MemoryConversationalStorage; use brightstaff::state::memory::MemoryConversationalStorage;
use brightstaff::state::postgresql::PostgreSQLConversationStorage;
use brightstaff::state::StateStorage;
use brightstaff::utils::tracing::init_tracer; use brightstaff::utils::tracing::init_tracer;
use bytes::Bytes; use bytes::Bytes;
use common::configuration::{Agent, Configuration}; use common::configuration::{Agent, Configuration};
use common::consts::{CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH, PLANO_ORCHESTRATOR_MODEL_NAME}; use common::consts::{
CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH, PLANO_ORCHESTRATOR_MODEL_NAME,
};
use common::traces::TraceCollector; use common::traces::TraceCollector;
use http_body_util::{combinators::BoxBody, BodyExt, Empty}; use http_body_util::{combinators::BoxBody, BodyExt, Empty};
use hyper::body::Incoming; use hyper::body::Incoming;
@ -105,7 +107,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
PLANO_ORCHESTRATOR_MODEL_NAME.to_string(), PLANO_ORCHESTRATOR_MODEL_NAME.to_string(),
)); ));
let model_aliases = Arc::new(arch_config.model_aliases.clone()); let model_aliases = Arc::new(arch_config.model_aliases.clone());
// Initialize trace collector and start background flusher // Initialize trace collector and start background flusher
@ -127,33 +128,33 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
// Configurable via arch_config.yaml state_storage section // Configurable via arch_config.yaml state_storage section
// If not configured, state management is disabled // If not configured, state management is disabled
// Environment variables are substituted by envsubst before config is read // Environment variables are substituted by envsubst before config is read
let state_storage: Option<Arc<dyn StateStorage>> = if let Some(storage_config) = &arch_config.state_storage { let state_storage: Option<Arc<dyn StateStorage>> =
let storage: Arc<dyn StateStorage> = match storage_config.storage_type { if let Some(storage_config) = &arch_config.state_storage {
common::configuration::StateStorageType::Memory => { let storage: Arc<dyn StateStorage> = match storage_config.storage_type {
info!("Initialized conversation state storage: Memory"); common::configuration::StateStorageType::Memory => {
Arc::new(MemoryConversationalStorage::new()) info!("Initialized conversation state storage: Memory");
} Arc::new(MemoryConversationalStorage::new())
common::configuration::StateStorageType::Postgres => { }
let connection_string = storage_config common::configuration::StateStorageType::Postgres => {
.connection_string let connection_string = storage_config
.as_ref() .connection_string
.expect("connection_string is required for postgres state_storage"); .as_ref()
.expect("connection_string is required for postgres state_storage");
debug!("Postgres connection string (full): {}", connection_string); debug!("Postgres connection string (full): {}", connection_string);
info!("Initializing conversation state storage: Postgres"); info!("Initializing conversation state storage: Postgres");
Arc::new( Arc::new(
PostgreSQLConversationStorage::new(connection_string.clone()) PostgreSQLConversationStorage::new(connection_string.clone())
.await .await
.expect("Failed to initialize Postgres state storage"), .expect("Failed to initialize Postgres state storage"),
) )
} }
};
Some(storage)
} else {
info!("No state_storage configured - conversation state management disabled");
None
}; };
Some(storage)
} else {
info!("No state_storage configured - conversation state management disabled");
None
};
loop { loop {
let (stream, _) = listener.accept().await?; let (stream, _) = listener.accept().await?;
@ -208,12 +209,22 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
} }
} }
match (req.method(), path) { match (req.method(), path) {
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH) => { (
let fully_qualified_url = &Method::POST,
format!("{}{}", llm_provider_url, path); CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH,
llm_chat(req, router_service, fully_qualified_url, model_aliases, llm_providers, trace_collector, state_storage) ) => {
.with_context(parent_cx) let fully_qualified_url = format!("{}{}", llm_provider_url, path);
.await llm_chat(
req,
router_service,
fully_qualified_url,
model_aliases,
llm_providers,
trace_collector,
state_storage,
)
.with_context(parent_cx)
.await
} }
(&Method::POST, "/function_calling") => { (&Method::POST, "/function_calling") => {
let fully_qualified_url = let fully_qualified_url =

View file

@ -2,7 +2,7 @@ use std::collections::HashMap;
use common::configuration::{AgentUsagePreference, OrchestrationPreference}; use common::configuration::{AgentUsagePreference, OrchestrationPreference};
use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role}; use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role};
use serde::{Deserialize, Serialize, ser::Serialize as SerializeTrait}; use serde::{ser::Serialize as SerializeTrait, Deserialize, Serialize};
use tracing::{debug, warn}; use tracing::{debug, warn};
use super::orchestrator_model::{OrchestratorModel, OrchestratorModelError}; use super::orchestrator_model::{OrchestratorModel, OrchestratorModelError};
@ -144,7 +144,7 @@ impl OrchestratorModelV1 {
// Format routes: each route as JSON on its own line with standard spacing // Format routes: each route as JSON on its own line with standard spacing
let agent_orchestration_json_str = agent_orchestration_values let agent_orchestration_json_str = agent_orchestration_values
.iter() .iter()
.map(|pref| to_spaced_json(pref)) .map(to_spaced_json)
.collect::<Vec<String>>() .collect::<Vec<String>>()
.join("\n"); .join("\n");
let agent_orchestration_to_model_map: HashMap<String, String> = agent_orchestrations let agent_orchestration_to_model_map: HashMap<String, String> = agent_orchestrations
@ -238,24 +238,26 @@ impl OrchestratorModel for OrchestratorModelV1 {
let selected_conversation_list = selected_messages_list_reversed let selected_conversation_list = selected_messages_list_reversed
.iter() .iter()
.rev() .rev()
.map(|message| { .map(|message| Message {
Message { role: message.role.clone(),
role: message.role.clone(), content: MessageContent::Text(message.content.to_string()),
content: MessageContent::Text(message.content.to_string()), name: None,
name: None, tool_calls: None,
tool_calls: None, tool_call_id: None,
tool_call_id: None,
}
}) })
.collect::<Vec<Message>>(); .collect::<Vec<Message>>();
// Generate the orchestrator request message based on the usage preferences. // Generate the orchestrator request message based on the usage preferences.
// If preferences are passed in request then we use them; // If preferences are passed in request then we use them;
// Otherwise, we use the default orchestration model preferences. // Otherwise, we use the default orchestration model preferences.
let orchestrator_message = match convert_to_orchestrator_preferences(usage_preferences_from_request) { let orchestrator_message =
Some(prefs) => generate_orchestrator_message(&prefs, &selected_conversation_list), match convert_to_orchestrator_preferences(usage_preferences_from_request) {
None => generate_orchestrator_message(&self.agent_orchestration_json_str, &selected_conversation_list), Some(prefs) => generate_orchestrator_message(&prefs, &selected_conversation_list),
}; None => generate_orchestrator_message(
&self.agent_orchestration_json_str,
&selected_conversation_list,
),
};
ChatCompletionsRequest { ChatCompletionsRequest {
model: self.orchestration_model.clone(), model: self.orchestration_model.clone(),
@ -280,7 +282,8 @@ impl OrchestratorModel for OrchestratorModelV1 {
return Ok(None); return Ok(None);
} }
let orchestrator_resp_fixed = fix_json_response(content); let orchestrator_resp_fixed = fix_json_response(content);
let orchestrator_response: AgentOrchestratorResponse = serde_json::from_str(orchestrator_resp_fixed.as_str())?; let orchestrator_response: AgentOrchestratorResponse =
serde_json::from_str(orchestrator_resp_fixed.as_str())?;
let selected_routes = orchestrator_response.route.unwrap_or_default(); let selected_routes = orchestrator_response.route.unwrap_or_default();
@ -320,7 +323,11 @@ impl OrchestratorModel for OrchestratorModelV1 {
} else { } else {
// If no usage preferences are passed in request then use the default orchestration model preferences // If no usage preferences are passed in request then use the default orchestration model preferences
for selected_route in valid_routes { for selected_route in valid_routes {
if let Some(model) = self.agent_orchestration_to_model_map.get(&selected_route).cloned() { if let Some(model) = self
.agent_orchestration_to_model_map
.get(&selected_route)
.cloned()
{
result.push((selected_route, model)); result.push((selected_route, model));
} else { } else {
warn!( warn!(
@ -375,7 +382,7 @@ fn convert_to_orchestrator_preferences(
// Format routes: each route as JSON on its own line with standard spacing // Format routes: each route as JSON on its own line with standard spacing
let routes_str = orchestration_preferences let routes_str = orchestration_preferences
.iter() .iter()
.map(|pref| to_spaced_json(pref)) .map(to_spaced_json)
.collect::<Vec<String>>() .collect::<Vec<String>>()
.join("\n"); .join("\n");
@ -425,7 +432,10 @@ mod tests {
// CRITICAL: Test that colons inside string values are NOT modified // CRITICAL: Test that colons inside string values are NOT modified
let with_colon = serde_json::json!({"name": "foo:bar", "url": "http://example.com"}); let with_colon = serde_json::json!({"name": "foo:bar", "url": "http://example.com"});
let result = to_spaced_json(&with_colon); let result = to_spaced_json(&with_colon);
assert_eq!(result, r#"{"name": "foo:bar", "url": "http://example.com"}"#); assert_eq!(
result,
r#"{"name": "foo:bar", "url": "http://example.com"}"#
);
// Test empty object and array // Test empty object and array
let empty_obj = serde_json::json!({}); let empty_obj = serde_json::json!({});
@ -446,7 +456,8 @@ mod tests {
}); });
let result = to_spaced_json(&complex); let result = to_spaced_json(&complex);
// Verify URLs with colons are preserved correctly // Verify URLs with colons are preserved correctly
assert!(result.contains(r#""urls": ["https://api.example.com:8080/path", "file:///local/path"]"#)); assert!(result
.contains(r#""urls": ["https://api.example.com:8080/path", "file:///local/path"]"#));
// Verify spacing format // Verify spacing format
assert!(result.contains(r#""type": "object""#)); assert!(result.contains(r#""type": "object""#));
assert!(result.contains(r#""properties": {}"#)); assert!(result.contains(r#""properties": {}"#));
@ -497,10 +508,16 @@ If no routes are needed, return an empty list for `route`.
] ]
} }
"#; "#;
let agent_orchestrations = let agent_orchestrations = serde_json::from_str::<
serde_json::from_str::<HashMap<String, Vec<OrchestrationPreference>>>(orchestrations_str).unwrap(); HashMap<String, Vec<OrchestrationPreference>>,
>(orchestrations_str)
.unwrap();
let orchestration_model = "test-model".to_string(); let orchestration_model = "test-model".to_string();
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), usize::MAX); let orchestrator = OrchestratorModelV1::new(
agent_orchestrations,
orchestration_model.clone(),
usize::MAX,
);
let conversation_str = r#" let conversation_str = r#"
[ [
@ -568,7 +585,11 @@ If no routes are needed, return an empty list for `route`.
// Empty orchestrations map - not used when usage_preferences are provided // Empty orchestrations map - not used when usage_preferences are provided
let agent_orchestrations: HashMap<String, Vec<OrchestrationPreference>> = HashMap::new(); let agent_orchestrations: HashMap<String, Vec<OrchestrationPreference>> = HashMap::new();
let orchestration_model = "test-model".to_string(); let orchestration_model = "test-model".to_string();
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), usize::MAX); let orchestrator = OrchestratorModelV1::new(
agent_orchestrations,
orchestration_model.clone(),
usize::MAX,
);
let conversation_str = r#" let conversation_str = r#"
[ [
@ -640,10 +661,13 @@ If no routes are needed, return an empty list for `route`.
] ]
} }
"#; "#;
let agent_orchestrations = let agent_orchestrations = serde_json::from_str::<
serde_json::from_str::<HashMap<String, Vec<OrchestrationPreference>>>(orchestrations_str).unwrap(); HashMap<String, Vec<OrchestrationPreference>>,
>(orchestrations_str)
.unwrap();
let orchestration_model = "test-model".to_string(); let orchestration_model = "test-model".to_string();
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), 235); let orchestrator =
OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), 235);
let conversation_str = r#" let conversation_str = r#"
[ [
@ -709,11 +733,14 @@ If no routes are needed, return an empty list for `route`.
] ]
} }
"#; "#;
let agent_orchestrations = let agent_orchestrations = serde_json::from_str::<
serde_json::from_str::<HashMap<String, Vec<OrchestrationPreference>>>(orchestrations_str).unwrap(); HashMap<String, Vec<OrchestrationPreference>>,
>(orchestrations_str)
.unwrap();
let orchestration_model = "test-model".to_string(); let orchestration_model = "test-model".to_string();
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), 200); let orchestrator =
OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), 200);
let conversation_str = r#" let conversation_str = r#"
[ [
@ -787,10 +814,13 @@ If no routes are needed, return an empty list for `route`.
] ]
} }
"#; "#;
let agent_orchestrations = let agent_orchestrations = serde_json::from_str::<
serde_json::from_str::<HashMap<String, Vec<OrchestrationPreference>>>(orchestrations_str).unwrap(); HashMap<String, Vec<OrchestrationPreference>>,
>(orchestrations_str)
.unwrap();
let orchestration_model = "test-model".to_string(); let orchestration_model = "test-model".to_string();
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), 230); let orchestrator =
OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), 230);
let conversation_str = r#" let conversation_str = r#"
[ [
@ -871,10 +901,16 @@ If no routes are needed, return an empty list for `route`.
] ]
} }
"#; "#;
let agent_orchestrations = let agent_orchestrations = serde_json::from_str::<
serde_json::from_str::<HashMap<String, Vec<OrchestrationPreference>>>(orchestrations_str).unwrap(); HashMap<String, Vec<OrchestrationPreference>>,
>(orchestrations_str)
.unwrap();
let orchestration_model = "test-model".to_string(); let orchestration_model = "test-model".to_string();
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), usize::MAX); let orchestrator = OrchestratorModelV1::new(
agent_orchestrations,
orchestration_model.clone(),
usize::MAX,
);
let conversation_str = r#" let conversation_str = r#"
[ [
@ -957,10 +993,16 @@ If no routes are needed, return an empty list for `route`.
] ]
} }
"#; "#;
let agent_orchestrations = let agent_orchestrations = serde_json::from_str::<
serde_json::from_str::<HashMap<String, Vec<OrchestrationPreference>>>(orchestrations_str).unwrap(); HashMap<String, Vec<OrchestrationPreference>>,
>(orchestrations_str)
.unwrap();
let orchestration_model = "test-model".to_string(); let orchestration_model = "test-model".to_string();
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), usize::MAX); let orchestrator = OrchestratorModelV1::new(
agent_orchestrations,
orchestration_model.clone(),
usize::MAX,
);
let conversation_str = r#" let conversation_str = r#"
[ [
@ -1034,10 +1076,13 @@ If no routes are needed, return an empty list for `route`.
] ]
} }
"#; "#;
let agent_orchestrations = let agent_orchestrations = serde_json::from_str::<
serde_json::from_str::<HashMap<String, Vec<OrchestrationPreference>>>(orchestrations_str).unwrap(); HashMap<String, Vec<OrchestrationPreference>>,
>(orchestrations_str)
.unwrap();
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, "test-model".to_string(), 2000); let orchestrator =
OrchestratorModelV1::new(agent_orchestrations, "test-model".to_string(), 2000);
// Case 1: Valid JSON with single route in array // Case 1: Valid JSON with single route in array
let input = r#"{"route": ["Image generation"]}"#; let input = r#"{"route": ["Image generation"]}"#;

View file

@ -34,10 +34,7 @@ pub enum OrchestrationError {
pub type Result<T> = std::result::Result<T, OrchestrationError>; pub type Result<T> = std::result::Result<T, OrchestrationError>;
impl OrchestratorService { impl OrchestratorService {
pub fn new( pub fn new(orchestrator_url: String, orchestration_model_name: String) -> Self {
orchestrator_url: String,
orchestration_model_name: String,
) -> Self {
// Empty agent orchestrations - will be provided via usage_preferences in requests // Empty agent orchestrations - will be provided via usage_preferences in requests
let agent_orchestrations: HashMap<String, Vec<OrchestrationPreference>> = HashMap::new(); let agent_orchestrations: HashMap<String, Vec<OrchestrationPreference>> = HashMap::new();

View file

@ -85,13 +85,19 @@ impl StateStorage for MemoryConversationalStorage {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use hermesllm::apis::openai_responses::{InputItem, InputMessage, MessageRole, InputContent, MessageContent}; use hermesllm::apis::openai_responses::{
InputContent, InputItem, InputMessage, MessageContent, MessageRole,
};
fn create_test_state(response_id: &str, num_messages: usize) -> OpenAIConversationState { fn create_test_state(response_id: &str, num_messages: usize) -> OpenAIConversationState {
let mut input_items = Vec::new(); let mut input_items = Vec::new();
for i in 0..num_messages { for i in 0..num_messages {
input_items.push(InputItem::Message(InputMessage { input_items.push(InputItem::Message(InputMessage {
role: if i % 2 == 0 { MessageRole::User } else { MessageRole::Assistant }, role: if i % 2 == 0 {
MessageRole::User
} else {
MessageRole::Assistant
},
content: MessageContent::Items(vec![InputContent::InputText { content: MessageContent::Items(vec![InputContent::InputText {
text: format!("Message {}", i), text: format!("Message {}", i),
}]), }]),
@ -252,7 +258,9 @@ mod tests {
let merged = storage.merge(&prev_state, current_input); let merged = storage.merge(&prev_state, current_input);
// Verify order: prev messages first, then current // Verify order: prev messages first, then current
let InputItem::Message(msg) = &merged[0] else { panic!("Expected Message") }; let InputItem::Message(msg) = &merged[0] else {
panic!("Expected Message")
};
match &msg.content { match &msg.content {
MessageContent::Items(items) => match &items[0] { MessageContent::Items(items) => match &items[0] {
InputContent::InputText { text } => assert_eq!(text, "Message 0"), InputContent::InputText { text } => assert_eq!(text, "Message 0"),
@ -261,7 +269,9 @@ mod tests {
_ => panic!("Expected MessageContent::Items"), _ => panic!("Expected MessageContent::Items"),
} }
let InputItem::Message(msg) = &merged[2] else { panic!("Expected Message") }; let InputItem::Message(msg) = &merged[2] else {
panic!("Expected Message")
};
match &msg.content { match &msg.content {
MessageContent::Items(items) => match &items[0] { MessageContent::Items(items) => match &items[0] {
InputContent::InputText { text } => assert_eq!(text, "Message 2"), InputContent::InputText { text } => assert_eq!(text, "Message 2"),
@ -404,7 +414,8 @@ mod tests {
let current_input = vec![InputItem::Message(InputMessage { let current_input = vec![InputItem::Message(InputMessage {
role: MessageRole::User, role: MessageRole::User,
content: MessageContent::Items(vec![InputContent::InputText { content: MessageContent::Items(vec![InputContent::InputText {
text: "Function result: {\"temperature\": 72, \"condition\": \"sunny\"}".to_string(), text: "Function result: {\"temperature\": 72, \"condition\": \"sunny\"}"
.to_string(),
}]), }]),
})]; })];
@ -415,7 +426,9 @@ mod tests {
assert_eq!(merged.len(), 3); assert_eq!(merged.len(), 3);
// Verify the order and content // Verify the order and content
let InputItem::Message(msg1) = &merged[0] else { panic!("Expected Message") }; let InputItem::Message(msg1) = &merged[0] else {
panic!("Expected Message")
};
assert!(matches!(msg1.role, MessageRole::User)); assert!(matches!(msg1.role, MessageRole::User));
match &msg1.content { match &msg1.content {
MessageContent::Items(items) => match &items[0] { MessageContent::Items(items) => match &items[0] {
@ -427,7 +440,9 @@ mod tests {
_ => panic!("Expected MessageContent::Items"), _ => panic!("Expected MessageContent::Items"),
} }
let InputItem::Message(msg2) = &merged[1] else { panic!("Expected Message") }; let InputItem::Message(msg2) = &merged[1] else {
panic!("Expected Message")
};
assert!(matches!(msg2.role, MessageRole::Assistant)); assert!(matches!(msg2.role, MessageRole::Assistant));
match &msg2.content { match &msg2.content {
MessageContent::Items(items) => match &items[0] { MessageContent::Items(items) => match &items[0] {
@ -439,7 +454,9 @@ mod tests {
_ => panic!("Expected MessageContent::Items"), _ => panic!("Expected MessageContent::Items"),
} }
let InputItem::Message(msg3) = &merged[2] else { panic!("Expected Message") }; let InputItem::Message(msg3) = &merged[2] else {
panic!("Expected Message")
};
assert!(matches!(msg3.role, MessageRole::User)); assert!(matches!(msg3.role, MessageRole::User));
match &msg3.content { match &msg3.content {
MessageContent::Items(items) => match &items[0] { MessageContent::Items(items) => match &items[0] {
@ -508,11 +525,15 @@ mod tests {
assert_eq!(merged.len(), 5); assert_eq!(merged.len(), 5);
// Verify first item is original user message // Verify first item is original user message
let InputItem::Message(first) = &merged[0] else { panic!("Expected Message") }; let InputItem::Message(first) = &merged[0] else {
panic!("Expected Message")
};
assert!(matches!(first.role, MessageRole::User)); assert!(matches!(first.role, MessageRole::User));
// Verify last two are function outputs // Verify last two are function outputs
let InputItem::Message(second_last) = &merged[3] else { panic!("Expected Message") }; let InputItem::Message(second_last) = &merged[3] else {
panic!("Expected Message")
};
assert!(matches!(second_last.role, MessageRole::User)); assert!(matches!(second_last.role, MessageRole::User));
match &second_last.content { match &second_last.content {
MessageContent::Items(items) => match &items[0] { MessageContent::Items(items) => match &items[0] {
@ -522,7 +543,9 @@ mod tests {
_ => panic!("Expected MessageContent::Items"), _ => panic!("Expected MessageContent::Items"),
} }
let InputItem::Message(last) = &merged[4] else { panic!("Expected Message") }; let InputItem::Message(last) = &merged[4] else {
panic!("Expected Message")
};
assert!(matches!(last.role, MessageRole::User)); assert!(matches!(last.role, MessageRole::User));
match &last.content { match &last.content {
MessageContent::Items(items) => match &items[0] { MessageContent::Items(items) => match &items[0] {
@ -590,7 +613,9 @@ mod tests {
assert_eq!(merged.len(), 5); assert_eq!(merged.len(), 5);
// Verify the entire conversation flow is preserved // Verify the entire conversation flow is preserved
let InputItem::Message(first) = &merged[0] else { panic!("Expected Message") }; let InputItem::Message(first) = &merged[0] else {
panic!("Expected Message")
};
match &first.content { match &first.content {
MessageContent::Items(items) => match &items[0] { MessageContent::Items(items) => match &items[0] {
InputContent::InputText { text } => assert!(text.contains("What's the weather")), InputContent::InputText { text } => assert!(text.contains("What's the weather")),
@ -599,7 +624,9 @@ mod tests {
_ => panic!("Expected MessageContent::Items"), _ => panic!("Expected MessageContent::Items"),
} }
let InputItem::Message(last) = &merged[4] else { panic!("Expected Message") }; let InputItem::Message(last) = &merged[4] else {
panic!("Expected Message")
};
match &last.content { match &last.content {
MessageContent::Items(items) => match &items[0] { MessageContent::Items(items) => match &items[0] {
InputContent::InputText { text } => assert!(text.contains("umbrella")), InputContent::InputText { text } => assert!(text.contains("umbrella")),

View file

@ -1,14 +1,16 @@
use async_trait::async_trait; use async_trait::async_trait;
use hermesllm::apis::openai_responses::{InputItem, InputMessage, InputContent, MessageContent, MessageRole, InputParam}; use hermesllm::apis::openai_responses::{
InputContent, InputItem, InputMessage, InputParam, MessageContent, MessageRole,
};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::error::Error; use std::error::Error;
use std::fmt; use std::fmt;
use std::sync::Arc; use std::sync::Arc;
use tracing::{debug}; use tracing::debug;
pub mod memory; pub mod memory;
pub mod response_state_processor;
pub mod postgresql; pub mod postgresql;
pub mod response_state_processor;
/// Represents the conversational state for a v1/responses request /// Represents the conversational state for a v1/responses request
/// Contains the complete input/output history that can be restored /// Contains the complete input/output history that can be restored
@ -47,7 +49,9 @@ pub enum StateStorageError {
impl fmt::Display for StateStorageError { impl fmt::Display for StateStorageError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self { match self {
StateStorageError::NotFound(id) => write!(f, "Conversation state not found for response_id: {}", id), StateStorageError::NotFound(id) => {
write!(f, "Conversation state not found for response_id: {}", id)
}
StateStorageError::StorageError(msg) => write!(f, "Storage error: {}", msg), StateStorageError::StorageError(msg) => write!(f, "Storage error: {}", msg),
StateStorageError::SerializationError(msg) => write!(f, "Serialization error: {}", msg), StateStorageError::SerializationError(msg) => write!(f, "Serialization error: {}", msg),
} }
@ -96,8 +100,6 @@ pub trait StateStorage: Send + Sync {
} }
} }
/// Storage backend type enum /// Storage backend type enum
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StorageBackend { pub enum StorageBackend {
@ -106,7 +108,7 @@ pub enum StorageBackend {
} }
impl StorageBackend { impl StorageBackend {
pub fn from_str(s: &str) -> Option<Self> { pub fn parse_backend(s: &str) -> Option<Self> {
match s.to_lowercase().as_str() { match s.to_lowercase().as_str() {
"memory" => Some(StorageBackend::Memory), "memory" => Some(StorageBackend::Memory),
"supabase" => Some(StorageBackend::Supabase), "supabase" => Some(StorageBackend::Supabase),
@ -139,7 +141,6 @@ pub async fn retrieve_and_combine_input(
previous_response_id: &str, previous_response_id: &str,
current_input: Vec<InputItem>, current_input: Vec<InputItem>,
) -> Result<Vec<InputItem>, StateStorageError> { ) -> Result<Vec<InputItem>, StateStorageError> {
// First get the previous state // First get the previous state
let prev_state = storage.get(previous_response_id).await?; let prev_state = storage.get(previous_response_id).await?;
let combined_input = storage.merge(&prev_state, current_input); let combined_input = storage.merge(&prev_state, current_input);

View file

@ -149,13 +149,12 @@ impl StateStorage for PostgreSQLConversationStorage {
let provider: String = row.get("provider"); let provider: String = row.get("provider");
// Deserialize input_items from JSONB // Deserialize input_items from JSONB
let input_items = let input_items = serde_json::from_value(input_items_json).map_err(|e| {
serde_json::from_value(input_items_json).map_err(|e| { StateStorageError::StorageError(format!(
StateStorageError::StorageError(format!( "Failed to deserialize input_items: {}",
"Failed to deserialize input_items: {}", e
e ))
)) })?;
})?;
Ok(OpenAIConversationState { Ok(OpenAIConversationState {
response_id, response_id,
@ -230,7 +229,9 @@ Run that SQL file against your database before using this storage backend.
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use hermesllm::apis::openai_responses::{InputContent, InputItem, InputMessage, MessageContent, MessageRole}; use hermesllm::apis::openai_responses::{
InputContent, InputItem, InputMessage, MessageContent, MessageRole,
};
fn create_test_state(response_id: &str) -> OpenAIConversationState { fn create_test_state(response_id: &str) -> OpenAIConversationState {
OpenAIConversationState { OpenAIConversationState {
@ -320,7 +321,10 @@ mod tests {
let result = storage.get("nonexistent_id").await; let result = storage.get("nonexistent_id").await;
assert!(result.is_err()); assert!(result.is_err());
assert!(matches!(result.unwrap_err(), StateStorageError::NotFound(_))); assert!(matches!(
result.unwrap_err(),
StateStorageError::NotFound(_)
));
} }
#[tokio::test] #[tokio::test]
@ -372,7 +376,10 @@ mod tests {
let result = storage.delete("nonexistent_id").await; let result = storage.delete("nonexistent_id").await;
assert!(result.is_err()); assert!(result.is_err());
assert!(matches!(result.unwrap_err(), StateStorageError::NotFound(_))); assert!(matches!(
result.unwrap_err(),
StateStorageError::NotFound(_)
));
} }
#[tokio::test] #[tokio::test]
@ -423,9 +430,13 @@ mod tests {
println!("✅ Data written to Supabase!"); println!("✅ Data written to Supabase!");
println!("Check your Supabase dashboard:"); println!("Check your Supabase dashboard:");
println!(" SELECT * FROM conversation_states WHERE response_id = 'manual_test_verification';"); println!(
" SELECT * FROM conversation_states WHERE response_id = 'manual_test_verification';"
);
println!("\nTo cleanup, run:"); println!("\nTo cleanup, run:");
println!(" DELETE FROM conversation_states WHERE response_id = 'manual_test_verification';"); println!(
" DELETE FROM conversation_states WHERE response_id = 'manual_test_verification';"
);
// DON'T cleanup - leave it for manual verification // DON'T cleanup - leave it for manual verification
} }

View file

@ -1,13 +1,11 @@
use bytes::Bytes; use bytes::Bytes;
use flate2::read::GzDecoder; use flate2::read::GzDecoder;
use hermesllm::apis::openai_responses::{ use hermesllm::apis::openai_responses::{InputItem, OutputItem, ResponsesAPIStreamEvent};
InputItem, OutputItem, ResponsesAPIStreamEvent,
};
use hermesllm::apis::streaming_shapes::sse::SseStreamIter; use hermesllm::apis::streaming_shapes::sse::SseStreamIter;
use hermesllm::transforms::response::output_to_input::outputs_to_inputs; use hermesllm::transforms::response::output_to_input::outputs_to_inputs;
use std::io::Read; use std::io::Read;
use std::sync::Arc; use std::sync::Arc;
use tracing::{info, debug, warn}; use tracing::{debug, info, warn};
use crate::handlers::utils::StreamProcessor; use crate::handlers::utils::StreamProcessor;
use crate::state::{OpenAIConversationState, StateStorage}; use crate::state::{OpenAIConversationState, StateStorage};
@ -53,6 +51,7 @@ pub struct ResponsesStateProcessor<P: StreamProcessor> {
} }
impl<P: StreamProcessor> ResponsesStateProcessor<P> { impl<P: StreamProcessor> ResponsesStateProcessor<P> {
#[allow(clippy::too_many_arguments)]
pub fn new( pub fn new(
inner: P, inner: P,
storage: Arc<dyn StateStorage>, storage: Arc<dyn StateStorage>,
@ -139,20 +138,19 @@ impl<P: StreamProcessor> ResponsesStateProcessor<P> {
for event in sse_iter { for event in sse_iter {
// Only process data lines (skip event-only lines) // Only process data lines (skip event-only lines)
if let Some(data_str) = &event.data { if let Some(data_str) = &event.data {
// Try to parse as ResponsesAPIStreamEvent // Try to parse as ResponsesAPIStreamEvent and check if it's a ResponseCompleted event
if let Ok(stream_event) = serde_json::from_str::<ResponsesAPIStreamEvent>(data_str) { if let Ok(ResponsesAPIStreamEvent::ResponseCompleted { response, .. }) =
// Check if this is a ResponseCompleted event serde_json::from_str::<ResponsesAPIStreamEvent>(data_str)
if let ResponsesAPIStreamEvent::ResponseCompleted { response, .. } = stream_event { {
info!( info!(
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Captured streaming response.completed: response_id={}, output_items={}", "[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Captured streaming response.completed: response_id={}, output_items={}",
self.request_id, self.request_id,
response.id, response.id,
response.output.len() response.output.len()
); );
self.response_id = Some(response.id.clone()); self.response_id = Some(response.id.clone());
self.output_items = Some(response.output.clone()); self.output_items = Some(response.output.clone());
return; // Found what we need, exit early return; // Found what we need, exit early
}
} }
} }
} }
@ -172,7 +170,9 @@ impl<P: StreamProcessor> ResponsesStateProcessor<P> {
let decompressed = self.decompress_buffer(); let decompressed = self.decompress_buffer();
// Parse complete JSON response // Parse complete JSON response
match serde_json::from_slice::<hermesllm::apis::openai_responses::ResponsesAPIResponse>(&decompressed) { match serde_json::from_slice::<hermesllm::apis::openai_responses::ResponsesAPIResponse>(
&decompressed,
) {
Ok(response) => { Ok(response) => {
info!( info!(
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Captured non-streaming response: response_id={}, output_items={}", "[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Captured non-streaming response: response_id={}, output_items={}",

View file

@ -2,11 +2,9 @@
/// ///
/// This module defines standard attribute keys following OTEL semantic conventions. /// This module defines standard attribute keys following OTEL semantic conventions.
/// See: https://opentelemetry.io/docs/specs/semconv/ /// See: https://opentelemetry.io/docs/specs/semconv/
// ============================================================================= // =============================================================================
// Span Attributes - HTTP // Span Attributes - HTTP
// ============================================================================= // =============================================================================
/// Semantic conventions for HTTP-related span attributes /// Semantic conventions for HTTP-related span attributes
pub mod http { pub mod http {
/// HTTP request method /// HTTP request method

View file

@ -1,3 +1,3 @@
mod constants; mod constants;
pub use constants::{OperationNameBuilder, operation_component, http, llm, error, routing}; pub use constants::{error, http, llm, operation_component, routing, OperationNameBuilder};

View file

@ -295,11 +295,14 @@ impl serde::Serialize for OrchestrationPreference {
let mut state = serializer.serialize_struct("OrchestrationPreference", 3)?; let mut state = serializer.serialize_struct("OrchestrationPreference", 3)?;
state.serialize_field("name", &self.name)?; state.serialize_field("name", &self.name)?;
state.serialize_field("description", &self.description)?; state.serialize_field("description", &self.description)?;
state.serialize_field("parameters", &serde_json::json!({ state.serialize_field(
"type": "object", "parameters",
"properties": {}, &serde_json::json!({
"required": [] "type": "object",
}))?; "properties": {},
"required": []
}),
)?;
state.end() state.end()
} }
} }
@ -489,7 +492,10 @@ mod test {
assert_eq!(config.version, "v0.3.0"); assert_eq!(config.version, "v0.3.0");
if let Some(prompt_targets) = &config.prompt_targets { if let Some(prompt_targets) = &config.prompt_targets {
assert!(!prompt_targets.is_empty(), "prompt_targets should not be empty if present"); assert!(
!prompt_targets.is_empty(),
"prompt_targets should not be empty if present"
);
} }
if let Some(tracing) = config.tracing.as_ref() { if let Some(tracing) = config.tracing.as_ref() {
@ -510,19 +516,48 @@ mod test {
.expect("reference config file not found"); .expect("reference config file not found");
let config: super::Configuration = serde_yaml::from_str(&ref_config).unwrap(); let config: super::Configuration = serde_yaml::from_str(&ref_config).unwrap();
if let Some(prompt_targets) = &config.prompt_targets { if let Some(prompt_targets) = &config.prompt_targets {
if let Some(prompt_target) = prompt_targets.iter().find(|p| p.name == "reboot_network_device") { if let Some(prompt_target) = prompt_targets
.iter()
.find(|p| p.name == "reboot_network_device")
{
let chat_completion_tool: super::ChatCompletionTool = prompt_target.into(); let chat_completion_tool: super::ChatCompletionTool = prompt_target.into();
assert_eq!(chat_completion_tool.tool_type, ToolType::Function); assert_eq!(chat_completion_tool.tool_type, ToolType::Function);
assert_eq!(chat_completion_tool.function.name, "reboot_network_device"); assert_eq!(chat_completion_tool.function.name, "reboot_network_device");
assert_eq!(chat_completion_tool.function.description, "Reboot a specific network device"); assert_eq!(
chat_completion_tool.function.description,
"Reboot a specific network device"
);
assert_eq!(chat_completion_tool.function.parameters.properties.len(), 2); assert_eq!(chat_completion_tool.function.parameters.properties.len(), 2);
assert!(chat_completion_tool.function.parameters.properties.contains_key("device_id")); assert!(chat_completion_tool
let device_id_param = chat_completion_tool.function.parameters.properties.get("device_id").unwrap(); .function
assert_eq!(device_id_param.parameter_type, crate::api::open_ai::ParameterType::String); .parameters
assert_eq!(device_id_param.description, "Identifier of the network device to reboot.".to_string()); .properties
.contains_key("device_id"));
let device_id_param = chat_completion_tool
.function
.parameters
.properties
.get("device_id")
.unwrap();
assert_eq!(
device_id_param.parameter_type,
crate::api::open_ai::ParameterType::String
);
assert_eq!(
device_id_param.description,
"Identifier of the network device to reboot.".to_string()
);
assert_eq!(device_id_param.required, Some(true)); assert_eq!(device_id_param.required, Some(true));
let confirmation_param = chat_completion_tool.function.parameters.properties.get("confirmation").unwrap(); let confirmation_param = chat_completion_tool
assert_eq!(confirmation_param.parameter_type, crate::api::open_ai::ParameterType::Bool); .function
.parameters
.properties
.get("confirmation")
.unwrap();
assert_eq!(
confirmation_param.parameter_type,
crate::api::open_ai::ParameterType::Bool
);
} }
} }
} }

View file

@ -32,6 +32,6 @@ pub const OTEL_COLLECTOR_HTTP: &str = "opentelemetry_collector_http";
pub const OTEL_POST_PATH: &str = "/v1/traces"; pub const OTEL_POST_PATH: &str = "/v1/traces";
pub const LLM_ROUTE_HEADER: &str = "x-arch-llm-route"; pub const LLM_ROUTE_HEADER: &str = "x-arch-llm-route";
pub const ENVOY_RETRY_HEADER: &str = "x-envoy-max-retries"; pub const ENVOY_RETRY_HEADER: &str = "x-envoy-max-retries";
pub const BRIGHT_STAFF_SERVICE_NAME : &str = "brightstaff"; pub const BRIGHT_STAFF_SERVICE_NAME: &str = "brightstaff";
pub const PLANO_ORCHESTRATOR_MODEL_NAME: &str = "Plano-Orchestrator"; pub const PLANO_ORCHESTRATOR_MODEL_NAME: &str = "Plano-Orchestrator";
pub const ARCH_FC_CLUSTER: &str = "arch"; pub const ARCH_FC_CLUSTER: &str = "arch";

View file

@ -10,6 +10,6 @@ pub mod ratelimit;
pub mod routing; pub mod routing;
pub mod stats; pub mod stats;
pub mod tokenizer; pub mod tokenizer;
pub mod tracing;
pub mod traces; pub mod traces;
pub mod tracing;
pub mod utils; pub mod utils;

View file

@ -41,7 +41,8 @@ pub fn get_llm_provider(
llm_providers llm_providers
.iter() .iter()
.filter(|(_, provider)| { .filter(|(_, provider)| {
provider.model provider
.model
.as_ref() .as_ref()
.map(|m| !m.starts_with("Arch")) .map(|m| !m.starts_with("Arch"))
.unwrap_or(true) .unwrap_or(true)

View file

@ -1,5 +1,5 @@
use super::shapes::Span;
use super::resource_span_builder::ResourceSpanBuilder; use super::resource_span_builder::ResourceSpanBuilder;
use super::shapes::Span;
use std::collections::{HashMap, VecDeque}; use std::collections::{HashMap, VecDeque};
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::Mutex; use tokio::sync::Mutex;
@ -160,7 +160,11 @@ impl TraceCollector {
} }
let total_spans: usize = service_batches.iter().map(|(_, spans)| spans.len()).sum(); let total_spans: usize = service_batches.iter().map(|(_, spans)| spans.len()).sum();
debug!("Flushing {} spans across {} services to OTEL collector", total_spans, service_batches.len()); debug!(
"Flushing {} spans across {} services to OTEL collector",
total_spans,
service_batches.len()
);
// Build canonical OTEL payload structure - one ResourceSpan per service // Build canonical OTEL payload structure - one ResourceSpan per service
let resource_spans = self.build_resource_spans(service_batches); let resource_spans = self.build_resource_spans(service_batches);
@ -178,7 +182,10 @@ impl TraceCollector {
} }
/// Build OTEL-compliant resource spans from collected spans, one ResourceSpan per service /// Build OTEL-compliant resource spans from collected spans, one ResourceSpan per service
fn build_resource_spans(&self, service_batches: Vec<(String, Vec<Span>)>) -> Vec<super::shapes::ResourceSpan> { fn build_resource_spans(
&self,
service_batches: Vec<(String, Vec<Span>)>,
) -> Vec<super::shapes::ResourceSpan> {
service_batches service_batches
.into_iter() .into_iter()
.map(|(service_name, spans)| { .map(|(service_name, spans)| {

View file

@ -1,7 +1,6 @@
/// OpenTelemetry semantic convention constants for tracing /// OpenTelemetry semantic convention constants for tracing
/// ///
/// These constants ensure consistency across the codebase and prevent typos /// These constants ensure consistency across the codebase and prevent typos
/// Resource attribute keys following OTEL semantic conventions /// Resource attribute keys following OTEL semantic conventions
pub mod resource { pub mod resource {
/// Logical name of the service /// Logical name of the service

View file

@ -1,9 +1,9 @@
// Original tracing types (OTEL structures) // Original tracing types (OTEL structures)
mod shapes; mod shapes;
// New tracing utilities // New tracing utilities
mod span_builder;
mod resource_span_builder;
mod constants; mod constants;
mod resource_span_builder;
mod span_builder;
#[cfg(feature = "trace-collection")] #[cfg(feature = "trace-collection")]
mod collector; mod collector;
@ -13,14 +13,14 @@ mod tests;
// Re-export original types // Re-export original types
pub use shapes::{ pub use shapes::{
Span, Event, Traceparent, TraceparentNewError, Attribute, AttributeValue, Event, Resource, ResourceSpan, Scope, ScopeSpan, Span, Traceparent,
ResourceSpan, Resource, ScopeSpan, Scope, Attribute, AttributeValue, TraceparentNewError,
}; };
// Re-export new utilities // Re-export new utilities
pub use span_builder::{SpanBuilder, SpanKind, generate_random_span_id};
pub use resource_span_builder::ResourceSpanBuilder;
pub use constants::*; pub use constants::*;
pub use resource_span_builder::ResourceSpanBuilder;
pub use span_builder::{generate_random_span_id, SpanBuilder, SpanKind};
#[cfg(feature = "trace-collection")] #[cfg(feature = "trace-collection")]
pub use collector::{TraceCollector, parse_traceparent}; pub use collector::{parse_traceparent, TraceCollector};

View file

@ -1,5 +1,5 @@
use super::shapes::{ResourceSpan, Resource, ScopeSpan, Scope, Span, Attribute, AttributeValue};
use super::constants::{resource, scope}; use super::constants::{resource, scope};
use super::shapes::{Attribute, AttributeValue, Resource, ResourceSpan, Scope, ScopeSpan, Span};
use std::collections::HashMap; use std::collections::HashMap;
/// Builder for creating OTEL ResourceSpan structures /// Builder for creating OTEL ResourceSpan structures
@ -26,7 +26,11 @@ impl ResourceSpanBuilder {
} }
/// Add a resource attribute (e.g., deployment.environment, host.name) /// Add a resource attribute (e.g., deployment.environment, host.name)
pub fn with_resource_attribute(mut self, key: impl Into<String>, value: impl Into<String>) -> Self { pub fn with_resource_attribute(
mut self,
key: impl Into<String>,
value: impl Into<String>,
) -> Self {
self.resource_attributes.insert(key.into(), value.into()); self.resource_attributes.insert(key.into(), value.into());
self self
} }
@ -58,14 +62,12 @@ impl ResourceSpanBuilder {
/// Build the ResourceSpan /// Build the ResourceSpan
pub fn build(self) -> ResourceSpan { pub fn build(self) -> ResourceSpan {
// Build resource attributes // Build resource attributes
let mut attributes = vec![ let mut attributes = vec![Attribute {
Attribute { key: resource::SERVICE_NAME.to_string(),
key: resource::SERVICE_NAME.to_string(), value: AttributeValue {
value: AttributeValue { string_value: Some(self.service_name),
string_value: Some(self.service_name), },
}, }];
}
];
// Add custom resource attributes // Add custom resource attributes
for (key, value) in self.resource_attributes { for (key, value) in self.resource_attributes {

View file

@ -1,4 +1,4 @@
use super::shapes::{Span, Attribute, AttributeValue}; use super::shapes::{Attribute, AttributeValue, Span};
use std::collections::HashMap; use std::collections::HashMap;
use std::time::SystemTime; use std::time::SystemTime;
@ -116,10 +116,11 @@ impl SpanBuilder {
let end_nanos = system_time_to_nanos(end_time); let end_nanos = system_time_to_nanos(end_time);
// Generate trace_id if not provided // Generate trace_id if not provided
let trace_id = self.trace_id.unwrap_or_else(|| generate_random_trace_id()); let trace_id = self.trace_id.unwrap_or_else(generate_random_trace_id);
// Create attributes in OTEL format // Create attributes in OTEL format
let attributes: Vec<Attribute> = self.attributes let attributes: Vec<Attribute> = self
.attributes
.into_iter() .into_iter()
.map(|(key, value)| Attribute { .map(|(key, value)| Attribute {
key, key,
@ -132,7 +133,7 @@ impl SpanBuilder {
// Build span directly without going through Span::new() // Build span directly without going through Span::new()
Span { Span {
trace_id, trace_id,
span_id: self.span_id.unwrap_or_else(|| generate_random_span_id()), span_id: self.span_id.unwrap_or_else(generate_random_span_id),
parent_span_id: self.parent_span_id, parent_span_id: self.parent_span_id,
name: self.name, name: self.name,
start_time_unix_nano: format!("{}", start_nanos), start_time_unix_nano: format!("{}", start_nanos),

View file

@ -21,10 +21,7 @@ use tokio::sync::RwLock;
type SharedTraces = Arc<RwLock<Vec<Value>>>; type SharedTraces = Arc<RwLock<Vec<Value>>>;
/// POST /v1/traces - capture incoming OTLP payload /// POST /v1/traces - capture incoming OTLP payload
async fn post_traces( async fn post_traces(State(traces): State<SharedTraces>, Json(payload): Json<Value>) -> StatusCode {
State(traces): State<SharedTraces>,
Json(payload): Json<Value>,
) -> StatusCode {
traces.write().await.push(payload); traces.write().await.push(payload);
StatusCode::OK StatusCode::OK
} }
@ -67,9 +64,7 @@ impl MockOtelCollector {
let address = format!("http://127.0.0.1:{}", addr.port()); let address = format!("http://127.0.0.1:{}", addr.port());
let server_handle = tokio::spawn(async move { let server_handle = tokio::spawn(async move {
axum::serve(listener, app) axum::serve(listener, app).await.expect("Server failed");
.await
.expect("Server failed");
}); });
// Give server a moment to start // Give server a moment to start

View file

@ -36,9 +36,12 @@ fn extract_spans(payloads: &[Value]) -> Vec<&Value> {
for payload in payloads { for payload in payloads {
if let Some(resource_spans) = payload.get("resourceSpans").and_then(|v| v.as_array()) { if let Some(resource_spans) = payload.get("resourceSpans").and_then(|v| v.as_array()) {
for resource_span in resource_spans { for resource_span in resource_spans {
if let Some(scope_spans) = resource_span.get("scopeSpans").and_then(|v| v.as_array()) { if let Some(scope_spans) =
resource_span.get("scopeSpans").and_then(|v| v.as_array())
{
for scope_span in scope_spans { for scope_span in scope_spans {
if let Some(span_list) = scope_span.get("spans").and_then(|v| v.as_array()) { if let Some(span_list) = scope_span.get("spans").and_then(|v| v.as_array())
{
spans.extend(span_list.iter()); spans.extend(span_list.iter());
} }
} }
@ -54,9 +57,9 @@ fn get_string_attr<'a>(span: &'a Value, key: &str) -> Option<&'a str> {
span.get("attributes") span.get("attributes")
.and_then(|attrs| attrs.as_array()) .and_then(|attrs| attrs.as_array())
.and_then(|attrs| { .and_then(|attrs| {
attrs.iter().find(|attr| { attrs
attr.get("key").and_then(|k| k.as_str()) == Some(key) .iter()
}) .find(|attr| attr.get("key").and_then(|k| k.as_str()) == Some(key))
}) })
.and_then(|attr| attr.get("value")) .and_then(|attr| attr.get("value"))
.and_then(|v| v.get("stringValue")) .and_then(|v| v.get("stringValue"))
@ -70,7 +73,10 @@ async fn test_llm_span_contains_basic_attributes() {
let mock_collector = MockOtelCollector::start().await; let mock_collector = MockOtelCollector::start().await;
// Create TraceCollector pointing to mock with 500ms flush intervalc // Create TraceCollector pointing to mock with 500ms flush intervalc
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address())); std::env::set_var(
"OTEL_COLLECTOR_URL",
format!("{}/v1/traces", mock_collector.address()),
);
std::env::set_var("OTEL_TRACING_ENABLED", "true"); std::env::set_var("OTEL_TRACING_ENABLED", "true");
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500"); std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
let trace_collector = Arc::new(TraceCollector::new(Some(true))); let trace_collector = Arc::new(TraceCollector::new(Some(true)));
@ -102,7 +108,10 @@ async fn test_llm_span_contains_basic_attributes() {
let span = spans[0]; let span = spans[0];
// Validate HTTP attributes // Validate HTTP attributes
assert_eq!(get_string_attr(span, "http.method"), Some("POST")); assert_eq!(get_string_attr(span, "http.method"), Some("POST"));
assert_eq!(get_string_attr(span, "http.target"), Some("/v1/chat/completions")); assert_eq!(
get_string_attr(span, "http.target"),
Some("/v1/chat/completions")
);
// Validate LLM attributes // Validate LLM attributes
assert_eq!(get_string_attr(span, "llm.model"), Some("gpt-4o")); assert_eq!(get_string_attr(span, "llm.model"), Some("gpt-4o"));
@ -115,7 +124,10 @@ async fn test_llm_span_contains_basic_attributes() {
#[serial] #[serial]
async fn test_llm_span_contains_tool_information() { async fn test_llm_span_contains_tool_information() {
let mock_collector = MockOtelCollector::start().await; let mock_collector = MockOtelCollector::start().await;
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address())); std::env::set_var(
"OTEL_COLLECTOR_URL",
format!("{}/v1/traces", mock_collector.address()),
);
std::env::set_var("OTEL_TRACING_ENABLED", "true"); std::env::set_var("OTEL_TRACING_ENABLED", "true");
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500"); std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
let trace_collector = Arc::new(TraceCollector::new(Some(true))); let trace_collector = Arc::new(TraceCollector::new(Some(true)));
@ -144,19 +156,26 @@ async fn test_llm_span_contains_tool_information() {
assert!(tools.unwrap().contains("get_weather(...)")); assert!(tools.unwrap().contains("get_weather(...)"));
assert!(tools.unwrap().contains("search_web(...)")); assert!(tools.unwrap().contains("search_web(...)"));
assert!(tools.unwrap().contains("calculate(...)")); assert!(tools.unwrap().contains("calculate(...)"));
assert!(tools.unwrap().contains('\n'), "Tools should be newline-separated"); assert!(
tools.unwrap().contains('\n'),
"Tools should be newline-separated"
);
} }
#[tokio::test] #[tokio::test]
#[serial] #[serial]
async fn test_llm_span_contains_user_message_preview() { async fn test_llm_span_contains_user_message_preview() {
let mock_collector = MockOtelCollector::start().await; let mock_collector = MockOtelCollector::start().await;
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address())); std::env::set_var(
"OTEL_COLLECTOR_URL",
format!("{}/v1/traces", mock_collector.address()),
);
std::env::set_var("OTEL_TRACING_ENABLED", "true"); std::env::set_var("OTEL_TRACING_ENABLED", "true");
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500"); std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
let trace_collector = Arc::new(TraceCollector::new(Some(true))); let trace_collector = Arc::new(TraceCollector::new(Some(true)));
let long_message = "This is a very long user message that should be truncated to 50 characters in the span"; let long_message =
"This is a very long user message that should be truncated to 50 characters in the span";
let preview = if long_message.len() > 50 { let preview = if long_message.len() > 50 {
format!("{}...", &long_message[..50]) format!("{}...", &long_message[..50])
} else { } else {
@ -187,7 +206,10 @@ async fn test_llm_span_contains_user_message_preview() {
#[serial] #[serial]
async fn test_llm_span_contains_time_to_first_token() { async fn test_llm_span_contains_time_to_first_token() {
let mock_collector = MockOtelCollector::start().await; let mock_collector = MockOtelCollector::start().await;
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address())); std::env::set_var(
"OTEL_COLLECTOR_URL",
format!("{}/v1/traces", mock_collector.address()),
);
std::env::set_var("OTEL_TRACING_ENABLED", "true"); std::env::set_var("OTEL_TRACING_ENABLED", "true");
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500"); std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
let trace_collector = Arc::new(TraceCollector::new(Some(true))); let trace_collector = Arc::new(TraceCollector::new(Some(true)));
@ -217,7 +239,10 @@ async fn test_llm_span_contains_time_to_first_token() {
#[serial] #[serial]
async fn test_llm_span_contains_upstream_path() { async fn test_llm_span_contains_upstream_path() {
let mock_collector = MockOtelCollector::start().await; let mock_collector = MockOtelCollector::start().await;
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address())); std::env::set_var(
"OTEL_COLLECTOR_URL",
format!("{}/v1/traces", mock_collector.address()),
);
std::env::set_var("OTEL_TRACING_ENABLED", "true"); std::env::set_var("OTEL_TRACING_ENABLED", "true");
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500"); std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
let trace_collector = Arc::new(TraceCollector::new(Some(true))); let trace_collector = Arc::new(TraceCollector::new(Some(true)));
@ -241,7 +266,10 @@ async fn test_llm_span_contains_upstream_path() {
// Operation name should show the transformation // Operation name should show the transformation
let name = span.get("name").and_then(|v| v.as_str()); let name = span.get("name").and_then(|v| v.as_str());
assert!(name.is_some()); assert!(name.is_some());
assert!(name.unwrap().contains(">>"), "Operation name should show path transformation"); assert!(
name.unwrap().contains(">>"),
"Operation name should show path transformation"
);
// Check upstream target attribute // Check upstream target attribute
let upstream = get_string_attr(span, "http.upstream_target"); let upstream = get_string_attr(span, "http.upstream_target");
@ -252,7 +280,10 @@ async fn test_llm_span_contains_upstream_path() {
#[serial] #[serial]
async fn test_llm_span_multiple_services() { async fn test_llm_span_multiple_services() {
let mock_collector = MockOtelCollector::start().await; let mock_collector = MockOtelCollector::start().await;
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address())); std::env::set_var(
"OTEL_COLLECTOR_URL",
format!("{}/v1/traces", mock_collector.address()),
);
std::env::set_var("OTEL_TRACING_ENABLED", "true"); std::env::set_var("OTEL_TRACING_ENABLED", "true");
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500"); std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
let trace_collector = Arc::new(TraceCollector::new(Some(true))); let trace_collector = Arc::new(TraceCollector::new(Some(true)));
@ -285,7 +316,10 @@ async fn test_tracing_disabled_produces_no_spans() {
let mock_collector = MockOtelCollector::start().await; let mock_collector = MockOtelCollector::start().await;
// Create TraceCollector with tracing DISABLED // Create TraceCollector with tracing DISABLED
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address())); std::env::set_var(
"OTEL_COLLECTOR_URL",
format!("{}/v1/traces", mock_collector.address()),
);
std::env::set_var("OTEL_TRACING_ENABLED", "false"); std::env::set_var("OTEL_TRACING_ENABLED", "false");
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500"); std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
let trace_collector = Arc::new(TraceCollector::new(Some(false))); let trace_collector = Arc::new(TraceCollector::new(Some(false)));
@ -300,5 +334,9 @@ async fn test_tracing_disabled_produces_no_spans() {
let payloads = mock_collector.get_traces().await; let payloads = mock_collector.get_traces().await;
let all_spans = extract_spans(&payloads); let all_spans = extract_spans(&payloads);
assert_eq!(all_spans.len(), 0, "No spans should be captured when tracing is disabled"); assert_eq!(
all_spans.len(),
0,
"No spans should be captured when tracing is disabled"
);
} }

View file

@ -161,13 +161,12 @@ impl TraceData {
} }
pub fn new_with_service_name(service_name: String) -> Self { pub fn new_with_service_name(service_name: String) -> Self {
let mut resource_attributes = Vec::new(); let resource_attributes = vec![Attribute {
resource_attributes.push(Attribute {
key: "service.name".to_string(), key: "service.name".to_string(),
value: AttributeValue { value: AttributeValue {
string_value: Some(service_name), string_value: Some(service_name),
}, },
}); }];
let resource = Resource { let resource = Resource {
attributes: resource_attributes, attributes: resource_attributes,
@ -194,7 +193,9 @@ impl TraceData {
pub fn add_span(&mut self, span: Span) { pub fn add_span(&mut self, span: Span) {
if self.resource_spans.is_empty() { if self.resource_spans.is_empty() {
let resource = Resource { attributes: Vec::new() }; let resource = Resource {
attributes: Vec::new(),
};
let scope_span = ScopeSpan { let scope_span = ScopeSpan {
scope: Scope { scope: Scope {
name: "default".to_string(), name: "default".to_string(),

View file

@ -66,7 +66,7 @@ impl ApiDefinition for AmazonBedrockApi {
/// Amazon Bedrock Converse request /// Amazon Bedrock Converse request
#[skip_serializing_none] #[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct ConverseRequest { pub struct ConverseRequest {
/// The model ID or ARN to invoke /// The model ID or ARN to invoke
pub model_id: String, pub model_id: String,
@ -91,7 +91,7 @@ pub struct ConverseRequest {
pub additional_model_response_field_paths: Option<Vec<String>>, pub additional_model_response_field_paths: Option<Vec<String>>,
/// Performance configuration /// Performance configuration
#[serde(rename = "performanceConfig")] #[serde(rename = "performanceConfig")]
pub performance_config: Option<PerformanceConfiguration>, pub performance_config: Option<InferenceConfiguration>,
/// Prompt variables for Prompt management /// Prompt variables for Prompt management
#[serde(rename = "promptVariables")] #[serde(rename = "promptVariables")]
pub prompt_variables: Option<HashMap<String, PromptVariableValues>>, pub prompt_variables: Option<HashMap<String, PromptVariableValues>>,
@ -105,26 +105,6 @@ pub struct ConverseRequest {
pub stream: bool, pub stream: bool,
} }
impl Default for ConverseRequest {
fn default() -> Self {
Self {
model_id: String::new(),
messages: None,
system: None,
inference_config: None,
tool_config: None,
guardrail_config: None,
additional_model_request_fields: None,
additional_model_response_field_paths: None,
performance_config: None,
prompt_variables: None,
request_metadata: None,
metadata: None,
stream: false,
}
}
}
/// Amazon Bedrock ConverseStream request (same structure as Converse) /// Amazon Bedrock ConverseStream request (same structure as Converse)
pub type ConverseStreamRequest = ConverseRequest; pub type ConverseStreamRequest = ConverseRequest;
@ -204,8 +184,8 @@ impl ProviderRequest for ConverseRequest {
self.tool_config.as_ref()?.tools.as_ref().map(|tools| { self.tool_config.as_ref()?.tools.as_ref().map(|tools| {
tools tools
.iter() .iter()
.filter_map(|tool| match tool { .map(|tool| match tool {
Tool::ToolSpec { tool_spec } => Some(tool_spec.name.clone()), Tool::ToolSpec { tool_spec } => tool_spec.name.clone(),
}) })
.collect() .collect()
}) })
@ -242,17 +222,14 @@ impl ProviderRequest for ConverseRequest {
// Add system messages if present // Add system messages if present
if let Some(system) = &self.system { if let Some(system) = &self.system {
for sys_block in system { for sys_block in system {
match sys_block { if let SystemContentBlock::Text { text } = sys_block {
SystemContentBlock::Text { text } => { openai_messages.push(Message {
openai_messages.push(Message { role: Role::System,
role: Role::System, content: MessageContent::Text(text.clone()),
content: MessageContent::Text(text.clone()), name: None,
name: None, tool_calls: None,
tool_calls: None, tool_call_id: None,
tool_call_id: None, });
});
}
_ => {} // Skip other system content types
} }
} }
} }
@ -266,7 +243,9 @@ impl ProviderRequest for ConverseRequest {
}; };
// Extract text from content blocks // Extract text from content blocks
let content = msg.content.iter() let content = msg
.content
.iter()
.filter_map(|block| { .filter_map(|block| {
if let ContentBlock::Text { text } = block { if let ContentBlock::Text { text } = block {
Some(text.clone()) Some(text.clone())
@ -311,16 +290,14 @@ impl ProviderRequest for ConverseRequest {
_ => continue, _ => continue,
}; };
let content = if let crate::apis::openai::MessageContent::Text(text) = &msg.content { let content =
vec![ContentBlock::Text { text: text.clone() }] if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
} else { vec![ContentBlock::Text { text: text.clone() }]
vec![] } else {
}; vec![]
};
bedrock_messages.push(crate::apis::amazon_bedrock::Message { bedrock_messages.push(crate::apis::amazon_bedrock::Message { role, content });
role,
content,
});
} }
_ => {} _ => {}
} }
@ -369,7 +346,7 @@ pub enum ConverseStreamEvent {
ContentBlockDelta(ContentBlockDeltaEvent), ContentBlockDelta(ContentBlockDeltaEvent),
ContentBlockStop(ContentBlockStopEvent), ContentBlockStop(ContentBlockStopEvent),
MessageStop(MessageStopEvent), MessageStop(MessageStopEvent),
Metadata(ConverseStreamMetadataEvent), Metadata(Box<ConverseStreamMetadataEvent>),
// Error events // Error events
InternalServerException(BedrockException), InternalServerException(BedrockException),
ModelStreamErrorException(BedrockException), ModelStreamErrorException(BedrockException),
@ -1063,7 +1040,7 @@ impl TryFrom<&aws_smithy_eventstream::frame::DecodedFrame> for ConverseStreamEve
"metadata" => { "metadata" => {
let event: ConverseStreamMetadataEvent = let event: ConverseStreamMetadataEvent =
serde_json::from_slice(payload).map_err(BedrockError::Serialization)?; serde_json::from_slice(payload).map_err(BedrockError::Serialization)?;
Ok(ConverseStreamEvent::Metadata(event)) Ok(ConverseStreamEvent::Metadata(Box::new(event)))
} }
unknown => Err(BedrockError::Validation { unknown => Err(BedrockError::Validation {
message: format!("Unknown event type: {}", unknown), message: format!("Unknown event type: {}", unknown),
@ -1106,10 +1083,10 @@ impl TryFrom<&aws_smithy_eventstream::frame::DecodedFrame> for ConverseStreamEve
} }
} }
impl Into<String> for ConverseStreamEvent { impl From<ConverseStreamEvent> for String {
fn into(self) -> String { fn from(val: ConverseStreamEvent) -> String {
let transformed_json = serde_json::to_string(&self).unwrap_or_default(); let transformed_json = serde_json::to_string(&val).unwrap_or_default();
let event_type = match &self { let event_type = match &val {
ConverseStreamEvent::MessageStart { .. } => "message_start", ConverseStreamEvent::MessageStart { .. } => "message_start",
ConverseStreamEvent::ContentBlockStart { .. } => "content_block_start", ConverseStreamEvent::ContentBlockStart { .. } => "content_block_start",
ConverseStreamEvent::ContentBlockDelta { .. } => "content_block_delta", ConverseStreamEvent::ContentBlockDelta { .. } => "content_block_delta",

File diff suppressed because one or more lines are too long

View file

@ -286,7 +286,6 @@ pub struct ImageUrl {
} }
/// A single message in a chat conversation /// A single message in a chat conversation
/// A tool call made by the assistant /// A tool call made by the assistant
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ToolCall { pub struct ToolCall {
@ -388,7 +387,7 @@ pub enum StaticContentType {
/// Chat completions API response /// Chat completions API response
#[skip_serializing_none] #[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct ChatCompletionsResponse { pub struct ChatCompletionsResponse {
pub id: String, pub id: String,
pub object: Option<String>, pub object: Option<String>,
@ -402,22 +401,6 @@ pub struct ChatCompletionsResponse {
pub metadata: Option<HashMap<String, Value>>, pub metadata: Option<HashMap<String, Value>>,
} }
impl Default for ChatCompletionsResponse {
fn default() -> Self {
ChatCompletionsResponse {
id: String::new(),
object: None,
created: 0,
model: String::new(),
choices: vec![],
usage: Usage::default(),
system_fingerprint: None,
service_tier: None,
metadata: None,
}
}
}
/// Finish reason for completion /// Finish reason for completion
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
@ -431,7 +414,7 @@ pub enum FinishReason {
/// Token usage information /// Token usage information
#[skip_serializing_none] #[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct Usage { pub struct Usage {
pub prompt_tokens: u32, pub prompt_tokens: u32,
pub completion_tokens: u32, pub completion_tokens: u32,
@ -440,18 +423,6 @@ pub struct Usage {
pub completion_tokens_details: Option<CompletionTokensDetails>, pub completion_tokens_details: Option<CompletionTokensDetails>,
} }
impl Default for Usage {
fn default() -> Self {
Usage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
prompt_tokens_details: None,
completion_tokens_details: None,
}
}
}
/// Detailed breakdown of prompt tokens /// Detailed breakdown of prompt tokens
#[skip_serializing_none] #[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
@ -472,7 +443,7 @@ pub struct CompletionTokensDetails {
/// A single choice in the response /// A single choice in the response
#[skip_serializing_none] #[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct Choice { pub struct Choice {
pub index: u32, pub index: u32,
pub message: ResponseMessage, pub message: ResponseMessage,
@ -480,17 +451,6 @@ pub struct Choice {
pub logprobs: Option<Value>, pub logprobs: Option<Value>,
} }
impl Default for Choice {
fn default() -> Self {
Choice {
index: 0,
message: ResponseMessage::default(),
finish_reason: None,
logprobs: None,
}
}
}
// ============================================================================ // ============================================================================
// STREAMING API TYPES // STREAMING API TYPES
// ============================================================================ // ============================================================================
@ -608,7 +568,6 @@ pub enum OpenAIError {
// ============================================================================ // ============================================================================
/// Trait Implementations /// Trait Implementations
/// =========================================================================== /// ===========================================================================
/// Parameterized conversion for ChatCompletionsRequest /// Parameterized conversion for ChatCompletionsRequest
impl TryFrom<&[u8]> for ChatCompletionsRequest { impl TryFrom<&[u8]> for ChatCompletionsRequest {
type Error = OpenAIStreamError; type Error = OpenAIStreamError;
@ -721,7 +680,7 @@ impl ProviderRequest for ChatCompletionsRequest {
} }
fn metadata(&self) -> &Option<HashMap<String, Value>> { fn metadata(&self) -> &Option<HashMap<String, Value>> {
return &self.metadata; &self.metadata
} }
fn remove_metadata_key(&mut self, key: &str) -> bool { fn remove_metadata_key(&mut self, key: &str) -> bool {

View file

@ -1,7 +1,7 @@
use std::collections::HashMap; use crate::providers::request::{ProviderRequest, ProviderRequestError};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none; use serde_with::skip_serializing_none;
use crate::providers::request::{ProviderRequest, ProviderRequestError}; use std::collections::HashMap;
impl TryFrom<&[u8]> for ResponsesAPIRequest { impl TryFrom<&[u8]> for ResponsesAPIRequest {
type Error = serde_json::Error; type Error = serde_json::Error;
@ -172,18 +172,14 @@ pub enum MessageRole {
#[serde(tag = "type", rename_all = "snake_case")] #[serde(tag = "type", rename_all = "snake_case")]
pub enum InputContent { pub enum InputContent {
/// Text input /// Text input
InputText { InputText { text: String },
text: String,
},
/// Image input via URL /// Image input via URL
InputImage { InputImage {
image_url: String, image_url: String,
detail: Option<String>, detail: Option<String>,
}, },
/// File input via URL /// File input via URL
InputFile { InputFile { file_url: String },
file_url: String,
},
/// Audio input /// Audio input
InputAudio { InputAudio {
data: Option<String>, data: Option<String>,
@ -222,9 +218,7 @@ pub struct TextConfig {
pub enum TextFormat { pub enum TextFormat {
Text, Text,
JsonObject, JsonObject,
JsonSchema { JsonSchema { json_schema: serde_json::Value },
json_schema: serde_json::Value,
},
} }
/// Reasoning effort levels /// Reasoning effort levels
@ -608,9 +602,7 @@ pub enum OutputContent {
transcript: Option<String>, transcript: Option<String>,
}, },
/// Refusal output /// Refusal output
Refusal { Refusal { refusal: String },
refusal: String,
},
} }
/// Annotations for output text /// Annotations for output text
@ -663,13 +655,9 @@ pub struct FileSearchResult {
#[serde(tag = "type", rename_all = "snake_case")] #[serde(tag = "type", rename_all = "snake_case")]
pub enum CodeInterpreterOutput { pub enum CodeInterpreterOutput {
/// Text output /// Text output
Text { Text { text: String },
text: String,
},
/// Image output /// Image output
Image { Image { image: String },
image: String,
},
} }
/// Response usage statistics /// Response usage statistics
@ -951,9 +939,7 @@ pub enum ResponsesAPIStreamEvent {
}, },
/// Done event (end of stream) /// Done event (end of stream)
Done { Done { sequence_number: i32 },
sequence_number: i32,
},
} }
// ============================================================================ // ============================================================================
@ -1052,12 +1038,19 @@ impl ProviderRequest for ResponsesAPIRequest {
MessageContent::Text(text) => text.clone(), MessageContent::Text(text) => text.clone(),
MessageContent::Items(content_items) => { MessageContent::Items(content_items) => {
content_items.iter().fold(String::new(), |acc, content| { content_items.iter().fold(String::new(), |acc, content| {
acc + " " + &match content { acc + " "
InputContent::InputText { text } => text.clone(), + &match content {
InputContent::InputImage { .. } => "[Image]".to_string(), InputContent::InputText { text } => text.clone(),
InputContent::InputFile { .. } => "[File]".to_string(), InputContent::InputImage { .. } => {
InputContent::InputAudio { .. } => "[Audio]".to_string(), "[Image]".to_string()
} }
InputContent::InputFile { .. } => {
"[File]".to_string()
}
InputContent::InputAudio { .. } => {
"[Audio]".to_string()
}
}
}) })
} }
}; };
@ -1082,11 +1075,9 @@ impl ProviderRequest for ResponsesAPIRequest {
match &msg.content { match &msg.content {
MessageContent::Text(text) => Some(text.clone()), MessageContent::Text(text) => Some(text.clone()),
MessageContent::Items(content_items) => { MessageContent::Items(content_items) => {
content_items.iter().find_map(|content| { content_items.iter().find_map(|content| match content {
match content { InputContent::InputText { text } => Some(text.clone()),
InputContent::InputText { text } => Some(text.clone()), _ => None,
_ => None,
}
}) })
} }
} }
@ -1176,9 +1167,12 @@ impl ProviderRequest for ResponsesAPIRequest {
// Extract text from message content // Extract text from message content
let content = match &msg.content { let content = match &msg.content {
crate::apis::openai_responses::MessageContent::Text(text) => text.clone(), crate::apis::openai_responses::MessageContent::Text(text) => {
text.clone()
}
crate::apis::openai_responses::MessageContent::Items(items) => { crate::apis::openai_responses::MessageContent::Items(items) => {
items.iter() items
.iter()
.filter_map(|c| { .filter_map(|c| {
if let InputContent::InputText { text } = c { if let InputContent::InputText { text } = c {
Some(text.clone()) Some(text.clone())
@ -1214,7 +1208,8 @@ impl ProviderRequest for ResponsesAPIRequest {
fn set_messages(&mut self, messages: &[crate::apis::openai::Message]) { fn set_messages(&mut self, messages: &[crate::apis::openai::Message]) {
// For ResponsesAPI, we need to convert messages back to input format // For ResponsesAPI, we need to convert messages back to input format
// Extract system messages as instructions // Extract system messages as instructions
let system_text = messages.iter() let system_text = messages
.iter()
.filter(|msg| msg.role == crate::apis::openai::Role::System) .filter(|msg| msg.role == crate::apis::openai::Role::System)
.filter_map(|msg| { .filter_map(|msg| {
if let crate::apis::openai::MessageContent::Text(text) = &msg.content { if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
@ -1233,23 +1228,27 @@ impl ProviderRequest for ResponsesAPIRequest {
// Convert user/assistant messages to InputParam // Convert user/assistant messages to InputParam
// For simplicity, we'll use the last user message as the input // For simplicity, we'll use the last user message as the input
// or combine all non-system messages // or combine all non-system messages
let input_messages: Vec<_> = messages.iter() let input_messages: Vec<_> = messages
.iter()
.filter(|msg| msg.role != crate::apis::openai::Role::System) .filter(|msg| msg.role != crate::apis::openai::Role::System)
.collect(); .collect();
if !input_messages.is_empty() { if !input_messages.is_empty() {
// If there's only one message, use Text format // If there's only one message, use Text format
if input_messages.len() == 1 { if input_messages.len() == 1 {
if let crate::apis::openai::MessageContent::Text(text) = &input_messages[0].content { if let crate::apis::openai::MessageContent::Text(text) = &input_messages[0].content
{
self.input = crate::apis::openai_responses::InputParam::Text(text.clone()); self.input = crate::apis::openai_responses::InputParam::Text(text.clone());
} }
} else { } else {
// Multiple messages - combine them as text for now // Multiple messages - combine them as text for now
// A more sophisticated approach would use InputParam::Items // A more sophisticated approach would use InputParam::Items
let combined_text = input_messages.iter() let combined_text = input_messages
.iter()
.filter_map(|msg| { .filter_map(|msg| {
if let crate::apis::openai::MessageContent::Text(text) = &msg.content { if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
Some(format!("{}: {}", Some(format!(
"{}: {}",
match msg.role { match msg.role {
crate::apis::openai::Role::User => "User", crate::apis::openai::Role::User => "User",
crate::apis::openai::Role::Assistant => "Assistant", crate::apis::openai::Role::Assistant => "Assistant",
@ -1274,10 +1273,10 @@ impl ProviderRequest for ResponsesAPIRequest {
// Into<String> Implementation for SSE Formatting // Into<String> Implementation for SSE Formatting
// ============================================================================ // ============================================================================
impl Into<String> for ResponsesAPIStreamEvent { impl From<ResponsesAPIStreamEvent> for String {
fn into(self) -> String { fn from(val: ResponsesAPIStreamEvent) -> Self {
let transformed_json = serde_json::to_string(&self).unwrap_or_default(); let transformed_json = serde_json::to_string(&val).unwrap_or_default();
let event_type = match &self { let event_type = match &val {
ResponsesAPIStreamEvent::ResponseCreated { .. } => "response.created", ResponsesAPIStreamEvent::ResponseCreated { .. } => "response.created",
ResponsesAPIStreamEvent::ResponseInProgress { .. } => "response.in_progress", ResponsesAPIStreamEvent::ResponseInProgress { .. } => "response.in_progress",
ResponsesAPIStreamEvent::ResponseCompleted { .. } => "response.completed", ResponsesAPIStreamEvent::ResponseCompleted { .. } => "response.completed",
@ -1365,10 +1364,10 @@ impl crate::providers::streaming_response::ProviderStreamResponse for ResponsesA
fn role(&self) -> Option<&str> { fn role(&self) -> Option<&str> {
match self { match self {
ResponsesAPIStreamEvent::ResponseOutputItemDone { item, .. } => match item { ResponsesAPIStreamEvent::ResponseOutputItemDone {
OutputItem::Message { role, .. } => Some(role.as_str()), item: OutputItem::Message { role, .. },
_ => None, ..
}, } => Some(role.as_str()),
_ => None, _ => None,
} }
} }

View file

@ -34,10 +34,7 @@ where
} }
pub fn decode_frame(&mut self) -> Option<DecodedFrame> { pub fn decode_frame(&mut self) -> Option<DecodedFrame> {
match self.decoder.decode_frame(&mut self.buffer) { self.decoder.decode_frame(&mut self.buffer).ok()
Ok(frame) => Some(frame),
Err(_e) => None, // Fatal decode error
}
} }
pub fn buffer_mut(&mut self) -> &mut B { pub fn buffer_mut(&mut self) -> &mut B {

View file

@ -1,5 +1,5 @@
use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait};
use crate::apis::anthropic::MessagesStreamEvent; use crate::apis::anthropic::MessagesStreamEvent;
use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait};
use crate::providers::streaming_response::ProviderStreamResponseType; use crate::providers::streaming_response::ProviderStreamResponseType;
use std::collections::HashSet; use std::collections::HashSet;
@ -31,6 +31,12 @@ pub struct AnthropicMessagesStreamBuffer {
model: Option<String>, model: Option<String>,
} }
impl Default for AnthropicMessagesStreamBuffer {
fn default() -> Self {
Self::new()
}
}
impl AnthropicMessagesStreamBuffer { impl AnthropicMessagesStreamBuffer {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
@ -154,7 +160,8 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
// Inject message_start if needed // Inject message_start if needed
if !self.message_started { if !self.message_started {
let model = self.model.as_deref().unwrap_or("unknown"); let model = self.model.as_deref().unwrap_or("unknown");
let message_start = AnthropicMessagesStreamBuffer::create_message_start_event(model); let message_start =
AnthropicMessagesStreamBuffer::create_message_start_event(model);
self.buffered_events.push(message_start); self.buffered_events.push(message_start);
self.message_started = true; self.message_started = true;
} }
@ -169,7 +176,8 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
// Inject message_start if needed // Inject message_start if needed
if !self.message_started { if !self.message_started {
let model = self.model.as_deref().unwrap_or("unknown"); let model = self.model.as_deref().unwrap_or("unknown");
let message_start = AnthropicMessagesStreamBuffer::create_message_start_event(model); let message_start =
AnthropicMessagesStreamBuffer::create_message_start_event(model);
self.buffered_events.push(message_start); self.buffered_events.push(message_start);
self.message_started = true; self.message_started = true;
} }
@ -177,7 +185,8 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
// Check if ContentBlockStart was sent for this index // Check if ContentBlockStart was sent for this index
if !self.has_content_block_start_been_sent(index) { if !self.has_content_block_start_been_sent(index) {
// Inject ContentBlockStart before delta // Inject ContentBlockStart before delta
let content_block_start = AnthropicMessagesStreamBuffer::create_content_block_start_event(); let content_block_start =
AnthropicMessagesStreamBuffer::create_content_block_start_event();
self.buffered_events.push(content_block_start); self.buffered_events.push(content_block_start);
self.set_content_block_start_sent(index); self.set_content_block_start_sent(index);
self.needs_content_block_stop = true; self.needs_content_block_stop = true;
@ -189,7 +198,8 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
MessagesStreamEvent::MessageDelta { usage, .. } => { MessagesStreamEvent::MessageDelta { usage, .. } => {
// Inject ContentBlockStop before message_delta // Inject ContentBlockStop before message_delta
if self.needs_content_block_stop { if self.needs_content_block_stop {
let content_block_stop = AnthropicMessagesStreamBuffer::create_content_block_stop_event(); let content_block_stop =
AnthropicMessagesStreamBuffer::create_content_block_stop_event();
self.buffered_events.push(content_block_stop); self.buffered_events.push(content_block_stop);
self.needs_content_block_stop = false; self.needs_content_block_stop = false;
} }
@ -199,10 +209,10 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
if let Some(last_event) = self.buffered_events.last_mut() { if let Some(last_event) = self.buffered_events.last_mut() {
if let Some(ProviderStreamResponseType::MessagesStreamEvent( if let Some(ProviderStreamResponseType::MessagesStreamEvent(
MessagesStreamEvent::MessageDelta { MessagesStreamEvent::MessageDelta {
usage: last_usage, usage: last_usage, ..
.. },
} )) = &mut last_event.provider_stream_response
)) = &mut last_event.provider_stream_response { {
// Merge: take stop_reason from first, usage from second (if non-zero) // Merge: take stop_reason from first, usage from second (if non-zero)
if usage.input_tokens > 0 || usage.output_tokens > 0 { if usage.input_tokens > 0 || usage.output_tokens > 0 {
*last_usage = usage.clone(); *last_usage = usage.clone();
@ -243,7 +253,7 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
} }
} }
fn into_bytes(&mut self) -> Vec<u8> { fn to_bytes(&mut self) -> Vec<u8> {
// Convert all accumulated events to bytes and clear buffer // Convert all accumulated events to bytes and clear buffer
// NOTE: We do NOT inject ContentBlockStop here because it's injected when we see MessageDelta // NOTE: We do NOT inject ContentBlockStop here because it's injected when we see MessageDelta
// or MessageStop. Injecting it here causes premature ContentBlockStop in the middle of streaming. // or MessageStop. Injecting it here causes premature ContentBlockStop in the middle of streaming.
@ -276,10 +286,10 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
use crate::apis::anthropic::AnthropicApi; use crate::apis::anthropic::AnthropicApi;
use crate::apis::openai::OpenAIApi; use crate::apis::openai::OpenAIApi;
use crate::apis::streaming_shapes::sse::SseStreamIter; use crate::apis::streaming_shapes::sse::SseStreamIter;
use crate::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
#[test] #[test]
fn test_openai_to_anthropic_complete_transformation() { fn test_openai_to_anthropic_complete_transformation() {
@ -308,11 +318,12 @@ data: [DONE]"#;
let mut buffer = AnthropicMessagesStreamBuffer::new(); let mut buffer = AnthropicMessagesStreamBuffer::new();
for raw_event in stream_iter { for raw_event in stream_iter {
let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap(); let transformed_event =
SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
buffer.add_transformed_event(transformed_event); buffer.add_transformed_event(transformed_event);
} }
let output_bytes = buffer.into_bytes(); let output_bytes = buffer.to_bytes();
let output = String::from_utf8_lossy(&output_bytes); let output = String::from_utf8_lossy(&output_bytes);
println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):"); println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):");
@ -321,25 +332,54 @@ data: [DONE]"#;
// Assertions // Assertions
assert!(!output_bytes.is_empty(), "Should have output"); assert!(!output_bytes.is_empty(), "Should have output");
assert!(output.contains("event: message_start"), "Should have message_start"); assert!(
assert!(output.contains("event: content_block_start"), "Should have content_block_start (injected)"); output.contains("event: message_start"),
"Should have message_start"
);
assert!(
output.contains("event: content_block_start"),
"Should have content_block_start (injected)"
);
let delta_count = output.matches("event: content_block_delta").count(); let delta_count = output.matches("event: content_block_delta").count();
assert_eq!(delta_count, 2, "Should have exactly 2 content_block_delta events"); assert_eq!(
delta_count, 2,
"Should have exactly 2 content_block_delta events"
);
// Verify both pieces of content are present // Verify both pieces of content are present
assert!(output.contains("\"text\":\"Hello\""), "Should have first content delta 'Hello'"); assert!(
assert!(output.contains("\"text\":\" world\""), "Should have second content delta ' world'"); output.contains("\"text\":\"Hello\""),
"Should have first content delta 'Hello'"
);
assert!(
output.contains("\"text\":\" world\""),
"Should have second content delta ' world'"
);
assert!(output.contains("event: content_block_stop"), "Should have content_block_stop (injected)"); assert!(
assert!(output.contains("event: message_delta"), "Should have message_delta"); output.contains("event: content_block_stop"),
assert!(output.contains("event: message_stop"), "Should have message_stop"); "Should have content_block_stop (injected)"
);
assert!(
output.contains("event: message_delta"),
"Should have message_delta"
);
assert!(
output.contains("event: message_stop"),
"Should have message_stop"
);
println!("\nVALIDATION SUMMARY:"); println!("\nVALIDATION SUMMARY:");
println!("{}", "-".repeat(80)); println!("{}", "-".repeat(80));
println!("✓ Complete transformation: OpenAI ChatCompletions → Anthropic Messages API"); println!("✓ Complete transformation: OpenAI ChatCompletions → Anthropic Messages API");
println!("✓ Injected lifecycle events: message_start, content_block_start, content_block_stop"); println!(
println!("✓ Content deltas: {} events (BOTH 'Hello' and ' world' preserved!)", delta_count); "✓ Injected lifecycle events: message_start, content_block_start, content_block_stop"
);
println!(
"✓ Content deltas: {} events (BOTH 'Hello' and ' world' preserved!)",
delta_count
);
println!("✓ Complete stream with message_stop"); println!("✓ Complete stream with message_stop");
println!("✓ Proper Anthropic protocol sequencing\n"); println!("✓ Proper Anthropic protocol sequencing\n");
} }
@ -369,11 +409,12 @@ data: {"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890
let mut buffer = AnthropicMessagesStreamBuffer::new(); let mut buffer = AnthropicMessagesStreamBuffer::new();
for raw_event in stream_iter { for raw_event in stream_iter {
let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap(); let transformed_event =
SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
buffer.add_transformed_event(transformed_event); buffer.add_transformed_event(transformed_event);
} }
let output_bytes = buffer.into_bytes(); let output_bytes = buffer.to_bytes();
let output = String::from_utf8_lossy(&output_bytes); let output = String::from_utf8_lossy(&output_bytes);
println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):"); println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):");
@ -382,31 +423,61 @@ data: {"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890
// Assertions // Assertions
assert!(!output_bytes.is_empty(), "Should have output"); assert!(!output_bytes.is_empty(), "Should have output");
assert!(output.contains("event: message_start"), "Should have message_start"); assert!(
assert!(output.contains("event: content_block_start"), "Should have content_block_start (injected)"); output.contains("event: message_start"),
"Should have message_start"
);
assert!(
output.contains("event: content_block_start"),
"Should have content_block_start (injected)"
);
let delta_count = output.matches("event: content_block_delta").count(); let delta_count = output.matches("event: content_block_delta").count();
assert_eq!(delta_count, 3, "Should have exactly 3 content_block_delta events"); assert_eq!(
delta_count, 3,
"Should have exactly 3 content_block_delta events"
);
// Verify all three pieces of content are present // Verify all three pieces of content are present
assert!(output.contains("\"text\":\"The weather\""), "Should have first content delta"); assert!(
assert!(output.contains("\"text\":\" in San Francisco\""), "Should have second content delta"); output.contains("\"text\":\"The weather\""),
assert!(output.contains("\"text\":\" is\""), "Should have third content delta"); "Should have first content delta"
);
assert!(
output.contains("\"text\":\" in San Francisco\""),
"Should have second content delta"
);
assert!(
output.contains("\"text\":\" is\""),
"Should have third content delta"
);
// For partial streams (no finish_reason, no [DONE]), we do NOT inject content_block_stop // For partial streams (no finish_reason, no [DONE]), we do NOT inject content_block_stop
// because the stream may continue. This is correct behavior - only inject lifecycle events // because the stream may continue. This is correct behavior - only inject lifecycle events
// when we have explicit signals from upstream (finish_reason, [DONE], etc.) // when we have explicit signals from upstream (finish_reason, [DONE], etc.)
assert!(!output.contains("event: content_block_stop"), "Should NOT have content_block_stop for partial stream"); assert!(
!output.contains("event: content_block_stop"),
"Should NOT have content_block_stop for partial stream"
);
// Should NOT have completion events // Should NOT have completion events
assert!(!output.contains("event: message_delta"), "Should NOT have message_delta"); assert!(
assert!(!output.contains("event: message_stop"), "Should NOT have message_stop"); !output.contains("event: message_delta"),
"Should NOT have message_delta"
);
assert!(
!output.contains("event: message_stop"),
"Should NOT have message_stop"
);
println!("\nVALIDATION SUMMARY:"); println!("\nVALIDATION SUMMARY:");
println!("{}", "-".repeat(80)); println!("{}", "-".repeat(80));
println!("✓ Partial transformation: OpenAI → Anthropic (stream interrupted)"); println!("✓ Partial transformation: OpenAI → Anthropic (stream interrupted)");
println!("✓ Injected: message_start, content_block_start at beginning"); println!("✓ Injected: message_start, content_block_start at beginning");
println!("✓ Incremental deltas: {} events (ALL content preserved!)", delta_count); println!(
"✓ Incremental deltas: {} events (ALL content preserved!)",
delta_count
);
println!("✓ NO completion events (partial stream, no [DONE])"); println!("✓ NO completion events (partial stream, no [DONE])");
println!("✓ Buffer maintains Anthropic protocol for active streams\n"); println!("✓ Buffer maintains Anthropic protocol for active streams\n");
} }
@ -452,11 +523,12 @@ data: [DONE]"#;
let mut buffer = AnthropicMessagesStreamBuffer::new(); let mut buffer = AnthropicMessagesStreamBuffer::new();
for raw_event in stream_iter { for raw_event in stream_iter {
let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap(); let transformed_event =
SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
buffer.add_transformed_event(transformed_event); buffer.add_transformed_event(transformed_event);
} }
let output_bytes = buffer.into_bytes(); let output_bytes = buffer.to_bytes();
let output = String::from_utf8_lossy(&output_bytes); let output = String::from_utf8_lossy(&output_bytes);
println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):"); println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):");
@ -467,32 +539,71 @@ data: [DONE]"#;
assert!(!output_bytes.is_empty(), "Should have output"); assert!(!output_bytes.is_empty(), "Should have output");
// Should have lifecycle events (injected by buffer) // Should have lifecycle events (injected by buffer)
assert!(output.contains("event: message_start"), "Should have message_start (injected)"); assert!(
assert!(output.contains("event: content_block_start"), "Should have content_block_start"); output.contains("event: message_start"),
assert!(output.contains("event: content_block_stop"), "Should have content_block_stop (injected)"); "Should have message_start (injected)"
assert!(output.contains("event: message_delta"), "Should have message_delta"); );
assert!(output.contains("event: message_stop"), "Should have message_stop"); assert!(
output.contains("event: content_block_start"),
"Should have content_block_start"
);
assert!(
output.contains("event: content_block_stop"),
"Should have content_block_stop (injected)"
);
assert!(
output.contains("event: message_delta"),
"Should have message_delta"
);
assert!(
output.contains("event: message_stop"),
"Should have message_stop"
);
// Should have tool_use content block // Should have tool_use content block
assert!(output.contains("\"type\":\"tool_use\""), "Should have tool_use type"); assert!(
assert!(output.contains("\"name\":\"get_weather\""), "Should have correct function name"); output.contains("\"type\":\"tool_use\""),
assert!(output.contains("\"id\":\"call_2Uzw0AEZQeOex2CP2TKjcLKc\""), "Should have correct tool call ID"); "Should have tool_use type"
);
assert!(
output.contains("\"name\":\"get_weather\""),
"Should have correct function name"
);
assert!(
output.contains("\"id\":\"call_2Uzw0AEZQeOex2CP2TKjcLKc\""),
"Should have correct tool call ID"
);
// Count input_json_delta events - should match the number of argument chunks // Count input_json_delta events - should match the number of argument chunks
let delta_count = output.matches("event: content_block_delta").count(); let delta_count = output.matches("event: content_block_delta").count();
assert!(delta_count >= 8, "Should have at least 8 input_json_delta events"); assert!(
delta_count >= 8,
"Should have at least 8 input_json_delta events"
);
// Verify argument deltas are present // Verify argument deltas are present
assert!(output.contains("\"type\":\"input_json_delta\""), "Should have input_json_delta type"); assert!(
assert!(output.contains("\"partial_json\":"), "Should have partial_json field"); output.contains("\"type\":\"input_json_delta\""),
"Should have input_json_delta type"
);
assert!(
output.contains("\"partial_json\":"),
"Should have partial_json field"
);
// Verify the accumulated arguments contain the location // Verify the accumulated arguments contain the location
assert!(output.contains("San"), "Arguments should contain 'San'"); assert!(output.contains("San"), "Arguments should contain 'San'");
assert!(output.contains("Francisco"), "Arguments should contain 'Francisco'"); assert!(
output.contains("Francisco"),
"Arguments should contain 'Francisco'"
);
assert!(output.contains("CA"), "Arguments should contain 'CA'"); assert!(output.contains("CA"), "Arguments should contain 'CA'");
// Verify stop reason is tool_use // Verify stop reason is tool_use
assert!(output.contains("\"stop_reason\":\"tool_use\""), "Should have stop_reason as tool_use"); assert!(
output.contains("\"stop_reason\":\"tool_use\""),
"Should have stop_reason as tool_use"
);
println!("\nVALIDATION SUMMARY:"); println!("\nVALIDATION SUMMARY:");
println!("{}", "-".repeat(80)); println!("{}", "-".repeat(80));

View file

@ -6,6 +6,12 @@ pub struct OpenAIChatCompletionsStreamBuffer {
buffered_events: Vec<SseEvent>, buffered_events: Vec<SseEvent>,
} }
/// `Default` support for `OpenAIChatCompletionsStreamBuffer`.
///
/// The default value is exactly what `OpenAIChatCompletionsStreamBuffer::new()`
/// returns, so `T::default()` and `T::new()` are interchangeable.
// NOTE(review): presumably added to satisfy clippy's `new_without_default`
// lint (the commit is titled "cargo clippy") — confirm.
impl Default for OpenAIChatCompletionsStreamBuffer {
/// Returns a fresh buffer by delegating to `Self::new()`, keeping a single
/// source of truth for the initial state.
fn default() -> Self {
Self::new()
}
}
impl OpenAIChatCompletionsStreamBuffer { impl OpenAIChatCompletionsStreamBuffer {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
@ -26,7 +32,7 @@ impl SseStreamBufferTrait for OpenAIChatCompletionsStreamBuffer {
self.buffered_events.push(event); self.buffered_events.push(event);
} }
fn into_bytes(&mut self) -> Vec<u8> { fn to_bytes(&mut self) -> Vec<u8> {
// No finalization needed for OpenAI Chat Completions // No finalization needed for OpenAI Chat Completions
// The [DONE] marker is already handled by the transformation layer // The [DONE] marker is already handled by the transformation layer
let mut buffer = Vec::new(); let mut buffer = Vec::new();

View file

@ -1,7 +1,7 @@
pub mod sse;
pub mod sse_chunk_processor;
pub mod amazon_bedrock_binary_frame; pub mod amazon_bedrock_binary_frame;
pub mod anthropic_streaming_buffer; pub mod anthropic_streaming_buffer;
pub mod chat_completions_streaming_buffer; pub mod chat_completions_streaming_buffer;
pub mod passthrough_streaming_buffer; pub mod passthrough_streaming_buffer;
pub mod responses_api_streaming_buffer; pub mod responses_api_streaming_buffer;
pub mod sse;
pub mod sse_chunk_processor;

View file

@ -6,6 +6,12 @@ pub struct PassthroughStreamBuffer {
buffered_events: Vec<SseEvent>, buffered_events: Vec<SseEvent>,
} }
/// `Default` support for `PassthroughStreamBuffer`.
///
/// The default value is exactly what `PassthroughStreamBuffer::new()` returns,
/// so `T::default()` and `T::new()` are interchangeable.
// NOTE(review): presumably added to satisfy clippy's `new_without_default`
// lint (the commit is titled "cargo clippy") — confirm.
impl Default for PassthroughStreamBuffer {
/// Returns a fresh buffer by delegating to `Self::new()`, keeping a single
/// source of truth for the initial state.
fn default() -> Self {
Self::new()
}
}
impl PassthroughStreamBuffer { impl PassthroughStreamBuffer {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
@ -30,7 +36,7 @@ impl SseStreamBufferTrait for PassthroughStreamBuffer {
self.buffered_events.push(event); self.buffered_events.push(event);
} }
fn into_bytes(&mut self) -> Vec<u8> { fn to_bytes(&mut self) -> Vec<u8> {
// No finalization needed for passthrough - just convert accumulated events to bytes // No finalization needed for passthrough - just convert accumulated events to bytes
let mut buffer = Vec::new(); let mut buffer = Vec::new();
for event in self.buffered_events.drain(..) { for event in self.buffered_events.drain(..) {
@ -44,7 +50,7 @@ impl SseStreamBufferTrait for PassthroughStreamBuffer {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::apis::streaming_shapes::passthrough_streaming_buffer::PassthroughStreamBuffer; use crate::apis::streaming_shapes::passthrough_streaming_buffer::PassthroughStreamBuffer;
use crate::apis::streaming_shapes::sse::{SseStreamIter, SseStreamBufferTrait}; use crate::apis::streaming_shapes::sse::{SseStreamBufferTrait, SseStreamIter};
#[test] #[test]
fn test_chat_completions_passthrough_buffer() { fn test_chat_completions_passthrough_buffer() {
@ -73,7 +79,7 @@ mod tests {
buffer.add_transformed_event(event); buffer.add_transformed_event(event);
} }
let output_bytes = buffer.into_bytes(); let output_bytes = buffer.to_bytes();
let output = String::from_utf8_lossy(&output_bytes); let output = String::from_utf8_lossy(&output_bytes);
println!("\nTRANSFORMED OUTPUT (ChatCompletions - Passthrough):"); println!("\nTRANSFORMED OUTPUT (ChatCompletions - Passthrough):");
@ -84,7 +90,11 @@ mod tests {
assert!(!output_bytes.is_empty()); assert!(!output_bytes.is_empty());
assert!(output.contains("chatcmpl-123")); assert!(output.contains("chatcmpl-123"));
assert!(output.contains("[DONE]")); assert!(output.contains("[DONE]"));
assert_eq!(raw_input.trim(), output.trim(), "Passthrough should preserve input"); assert_eq!(
raw_input.trim(),
output.trim(),
"Passthrough should preserve input"
);
println!("\nVALIDATION SUMMARY:"); println!("\nVALIDATION SUMMARY:");
println!("{}", "-".repeat(80)); println!("{}", "-".repeat(80));

View file

@ -1,10 +1,10 @@
use std::collections::HashMap;
use log::debug;
use crate::apis::openai_responses::{ use crate::apis::openai_responses::{
ResponsesAPIStreamEvent, ResponsesAPIResponse, OutputItem, OutputItemStatus, OutputItem, OutputItemStatus, Reasoning, ResponseStatus, ResponsesAPIResponse,
ResponseStatus, TextConfig, TextFormat, Reasoning, ResponsesAPIStreamEvent, TextConfig, TextFormat,
}; };
use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait}; use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait};
use log::debug;
use std::collections::HashMap;
/// Helper to convert ResponseAPIStreamEvent to SseEvent /// Helper to convert ResponseAPIStreamEvent to SseEvent
fn event_to_sse(event: ResponsesAPIStreamEvent) -> SseEvent { fn event_to_sse(event: ResponsesAPIStreamEvent) -> SseEvent {
@ -16,10 +16,17 @@ fn event_to_sse(event: ResponsesAPIStreamEvent) -> SseEvent {
ResponsesAPIStreamEvent::ResponseOutputItemDone { .. } => "response.output_item.done", ResponsesAPIStreamEvent::ResponseOutputItemDone { .. } => "response.output_item.done",
ResponsesAPIStreamEvent::ResponseOutputTextDelta { .. } => "response.output_text.delta", ResponsesAPIStreamEvent::ResponseOutputTextDelta { .. } => "response.output_text.delta",
ResponsesAPIStreamEvent::ResponseOutputTextDone { .. } => "response.output_text.done", ResponsesAPIStreamEvent::ResponseOutputTextDone { .. } => "response.output_text.done",
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { .. } => "response.function_call_arguments.delta", ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { .. } => {
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDone { .. } => "response.function_call_arguments.done", "response.function_call_arguments.delta"
}
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDone { .. } => {
"response.function_call_arguments.done"
}
unknown => { unknown => {
debug!("Unknown ResponsesAPIStreamEvent type encountered: {:?}", unknown); debug!(
"Unknown ResponsesAPIStreamEvent type encountered: {:?}",
unknown
);
"unknown" "unknown"
} }
}; };
@ -85,6 +92,12 @@ pub struct ResponsesAPIStreamBuffer {
buffered_events: Vec<SseEvent>, buffered_events: Vec<SseEvent>,
} }
/// `Default` support for `ResponsesAPIStreamBuffer`.
///
/// The default value is exactly what `ResponsesAPIStreamBuffer::new()` returns,
/// so `T::default()` and `T::new()` are interchangeable.
// NOTE(review): presumably added to satisfy clippy's `new_without_default`
// lint (the commit is titled "cargo clippy") — confirm.
impl Default for ResponsesAPIStreamBuffer {
/// Returns a fresh buffer by delegating to `Self::new()`, keeping a single
/// source of truth for the initial state.
fn default() -> Self {
Self::new()
}
}
impl ResponsesAPIStreamBuffer { impl ResponsesAPIStreamBuffer {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
@ -112,7 +125,11 @@ impl ResponsesAPIStreamBuffer {
} }
fn generate_item_id(prefix: &str) -> String { fn generate_item_id(prefix: &str) -> String {
format!("{}_{}", prefix, uuid::Uuid::new_v4().to_string().replace("-", "")) format!(
"{}_{}",
prefix,
uuid::Uuid::new_v4().to_string().replace("-", "")
)
} }
fn get_or_create_item_id(&mut self, output_index: i32, prefix: &str) -> String { fn get_or_create_item_id(&mut self, output_index: i32, prefix: &str) -> String {
@ -160,7 +177,13 @@ impl ResponsesAPIStreamBuffer {
} }
/// Create output_item.added event for tool call /// Create output_item.added event for tool call
fn create_tool_call_added_event(&mut self, output_index: i32, item_id: &str, call_id: &str, name: &str) -> SseEvent { fn create_tool_call_added_event(
&mut self,
output_index: i32,
item_id: &str,
call_id: &str,
name: &str,
) -> SseEvent {
let event = ResponsesAPIStreamEvent::ResponseOutputItemAdded { let event = ResponsesAPIStreamEvent::ResponseOutputItemAdded {
output_index, output_index,
item: OutputItem::FunctionCall { item: OutputItem::FunctionCall {
@ -237,9 +260,15 @@ impl ResponsesAPIStreamBuffer {
// Emit done events for all accumulated content // Emit done events for all accumulated content
// Text content done events // Text content done events
let text_items: Vec<_> = self.text_content.iter().map(|(id, content)| (id.clone(), content.clone())).collect(); let text_items: Vec<_> = self
.text_content
.iter()
.map(|(id, content)| (id.clone(), content.clone()))
.collect();
for (item_id, content) in text_items { for (item_id, content) in text_items {
let output_index = self.output_items_added.iter() let output_index = self
.output_items_added
.iter()
.find(|(_, id)| **id == item_id) .find(|(_, id)| **id == item_id)
.map(|(idx, _)| *idx) .map(|(idx, _)| *idx)
.unwrap_or(0); .unwrap_or(0);
@ -270,9 +299,15 @@ impl ResponsesAPIStreamBuffer {
} }
// Function call done events // Function call done events
let func_items: Vec<_> = self.function_arguments.iter().map(|(id, args)| (id.clone(), args.clone())).collect(); let func_items: Vec<_> = self
.function_arguments
.iter()
.map(|(id, args)| (id.clone(), args.clone()))
.collect();
for (item_id, arguments) in func_items { for (item_id, arguments) in func_items {
let output_index = self.output_items_added.iter() let output_index = self
.output_items_added
.iter()
.find(|(_, id)| **id == item_id) .find(|(_, id)| **id == item_id)
.map(|(idx, _)| *idx) .map(|(idx, _)| *idx)
.unwrap_or(0); .unwrap_or(0);
@ -286,9 +321,16 @@ impl ResponsesAPIStreamBuffer {
}; };
events.push(event_to_sse(args_done_event)); events.push(event_to_sse(args_done_event));
let (call_id, name) = self.tool_call_metadata.get(&output_index) let (call_id, name) = self
.tool_call_metadata
.get(&output_index)
.cloned() .cloned()
.unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string())); .unwrap_or_else(|| {
(
format!("call_{}", uuid::Uuid::new_v4()),
"unknown".to_string(),
)
});
let seq2 = self.next_sequence_number(); let seq2 = self.next_sequence_number();
let item_done_event = ResponsesAPIStreamEvent::ResponseOutputItemDone { let item_done_event = ResponsesAPIStreamEvent::ResponseOutputItemDone {
@ -315,9 +357,16 @@ impl ResponsesAPIStreamBuffer {
if let Some(item_id) = self.output_items_added.get(&output_index) { if let Some(item_id) = self.output_items_added.get(&output_index) {
// Check if this is a function call // Check if this is a function call
if let Some(arguments) = self.function_arguments.get(item_id) { if let Some(arguments) = self.function_arguments.get(item_id) {
let (call_id, name) = self.tool_call_metadata.get(&output_index) let (call_id, name) = self
.tool_call_metadata
.get(&output_index)
.cloned() .cloned()
.unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string())); .unwrap_or_else(|| {
(
format!("call_{}", uuid::Uuid::new_v4()),
"unknown".to_string(),
)
});
output_items.push(OutputItem::FunctionCall { output_items.push(OutputItem::FunctionCall {
id: item_id.clone(), id: item_id.clone(),
@ -397,9 +446,9 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
let mut events = Vec::new(); let mut events = Vec::new();
// Capture upstream metadata from ResponseCreated or ResponseInProgress if present // Capture upstream metadata from ResponseCreated or ResponseInProgress if present
match stream_event { match stream_event.as_ref() {
ResponsesAPIStreamEvent::ResponseCreated { response, .. } | ResponsesAPIStreamEvent::ResponseCreated { response, .. }
ResponsesAPIStreamEvent::ResponseInProgress { response, .. } => { | ResponsesAPIStreamEvent::ResponseInProgress { response, .. } => {
if self.upstream_response_metadata.is_none() { if self.upstream_response_metadata.is_none() {
// Store the full upstream response as our metadata template // Store the full upstream response as our metadata template
self.upstream_response_metadata = Some(response.clone()); self.upstream_response_metadata = Some(response.clone());
@ -418,11 +467,16 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
if !self.created_emitted { if !self.created_emitted {
// Initialize metadata from first event if needed // Initialize metadata from first event if needed
if self.response_id.is_none() { if self.response_id.is_none() {
self.response_id = Some(format!("resp_{}", uuid::Uuid::new_v4().to_string().replace("-", ""))); self.response_id = Some(format!(
self.created_at = Some(std::time::SystemTime::now() "resp_{}",
.duration_since(std::time::UNIX_EPOCH) uuid::Uuid::new_v4().to_string().replace("-", "")
.unwrap() ));
.as_secs() as i64); self.created_at = Some(
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs() as i64,
);
self.model = Some("unknown".to_string()); // Will be set by caller if available self.model = Some("unknown".to_string()); // Will be set by caller if available
} }
@ -436,58 +490,95 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
} }
// Process the delta event // Process the delta event
match stream_event { match stream_event.as_ref() {
ResponsesAPIStreamEvent::ResponseOutputTextDelta { output_index, delta, .. } => { ResponsesAPIStreamEvent::ResponseOutputTextDelta {
output_index,
delta,
..
} => {
let item_id = self.get_or_create_item_id(*output_index, "msg"); let item_id = self.get_or_create_item_id(*output_index, "msg");
// Emit output_item.added if this is the first time we see this output index // Emit output_item.added if this is the first time we see this output index
if !self.output_items_added.contains_key(output_index) { if !self.output_items_added.contains_key(output_index) {
self.output_items_added.insert(*output_index, item_id.clone()); self.output_items_added
.insert(*output_index, item_id.clone());
events.push(self.create_output_item_added_event(*output_index, &item_id)); events.push(self.create_output_item_added_event(*output_index, &item_id));
} }
// Accumulate text content // Accumulate text content
self.text_content.entry(item_id.clone()) self.text_content
.entry(item_id.clone())
.and_modify(|content| content.push_str(delta)) .and_modify(|content| content.push_str(delta))
.or_insert_with(|| delta.clone()); .or_insert_with(|| delta.clone());
// Emit text delta with filled-in item_id and sequence_number // Emit text delta with filled-in item_id and sequence_number
let mut delta_event = stream_event.clone(); let mut delta_event = stream_event.as_ref().clone();
if let ResponsesAPIStreamEvent::ResponseOutputTextDelta { item_id: ref mut id, sequence_number: ref mut seq, .. } = delta_event { if let ResponsesAPIStreamEvent::ResponseOutputTextDelta {
item_id: ref mut id,
sequence_number: ref mut seq,
..
} = &mut delta_event
{
*id = item_id; *id = item_id;
*seq = self.next_sequence_number(); *seq = self.next_sequence_number();
} }
events.push(event_to_sse(delta_event)); events.push(event_to_sse(delta_event));
} }
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { output_index, delta, call_id, name, .. } => { ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta {
output_index,
delta,
call_id,
name,
..
} => {
let item_id = self.get_or_create_item_id(*output_index, "fc"); let item_id = self.get_or_create_item_id(*output_index, "fc");
// Store metadata if provided (from initial tool call event) // Store metadata if provided (from initial tool call event)
if let (Some(cid), Some(n)) = (call_id, name) { if let (Some(cid), Some(n)) = (call_id, name) {
self.tool_call_metadata.insert(*output_index, (cid.clone(), n.clone())); self.tool_call_metadata
.insert(*output_index, (cid.clone(), n.clone()));
} }
// Emit output_item.added if this is the first time we see this tool call // Emit output_item.added if this is the first time we see this tool call
if !self.output_items_added.contains_key(output_index) { if !self.output_items_added.contains_key(output_index) {
self.output_items_added.insert(*output_index, item_id.clone()); self.output_items_added
.insert(*output_index, item_id.clone());
// For tool calls, we need call_id and name from metadata // For tool calls, we need call_id and name from metadata
// These should now be populated from the event itself // These should now be populated from the event itself
let (call_id, name) = self.tool_call_metadata.get(output_index) let (call_id, name) = self
.tool_call_metadata
.get(output_index)
.cloned() .cloned()
.unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string())); .unwrap_or_else(|| {
(
format!("call_{}", uuid::Uuid::new_v4()),
"unknown".to_string(),
)
});
events.push(self.create_tool_call_added_event(*output_index, &item_id, &call_id, &name)); events.push(self.create_tool_call_added_event(
*output_index,
&item_id,
&call_id,
&name,
));
} }
// Accumulate function arguments // Accumulate function arguments
self.function_arguments.entry(item_id.clone()) self.function_arguments
.entry(item_id.clone())
.and_modify(|args| args.push_str(delta)) .and_modify(|args| args.push_str(delta))
.or_insert_with(|| delta.clone()); .or_insert_with(|| delta.clone());
// Emit function call arguments delta with filled-in item_id and sequence_number // Emit function call arguments delta with filled-in item_id and sequence_number
let mut delta_event = stream_event.clone(); let mut delta_event = stream_event.as_ref().clone();
if let ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { item_id: ref mut id, sequence_number: ref mut seq, .. } = delta_event { if let ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta {
item_id: ref mut id,
sequence_number: ref mut seq,
..
} = &mut delta_event
{
*id = item_id; *id = item_id;
*seq = self.next_sequence_number(); *seq = self.next_sequence_number();
} }
@ -495,7 +586,7 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
} }
_ => { _ => {
// For other event types, just pass through with sequence number // For other event types, just pass through with sequence number
let other_event = stream_event.clone(); let other_event = stream_event.as_ref().clone();
// TODO: Add sequence number to other event types if needed // TODO: Add sequence number to other event types if needed
events.push(event_to_sse(other_event)); events.push(event_to_sse(other_event));
} }
@ -505,8 +596,7 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
self.buffered_events.extend(events); self.buffered_events.extend(events);
} }
fn to_bytes(&mut self) -> Vec<u8> {
fn into_bytes(&mut self) -> Vec<u8> {
// For Responses API, we need special handling: // For Responses API, we need special handling:
// - Most events are already in buffered_events from add_transformed_event // - Most events are already in buffered_events from add_transformed_event
// - We should NOT finalize here - finalization happens when we detect [DONE] or end of stream // - We should NOT finalize here - finalization happens when we detect [DONE] or end of stream
@ -525,9 +615,9 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
use crate::apis::openai::OpenAIApi; use crate::apis::openai::OpenAIApi;
use crate::apis::streaming_shapes::sse::SseStreamIter; use crate::apis::streaming_shapes::sse::SseStreamIter;
use crate::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
#[test] #[test]
fn test_chat_completions_to_responses_api_transformation() { fn test_chat_completions_to_responses_api_transformation() {
@ -557,11 +647,12 @@ mod tests {
for raw_event in stream_iter { for raw_event in stream_iter {
// Transform the event using the client/upstream APIs // Transform the event using the client/upstream APIs
let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap(); let transformed_event =
SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
buffer.add_transformed_event(transformed_event); buffer.add_transformed_event(transformed_event);
} }
let output_bytes = buffer.into_bytes(); let output_bytes = buffer.to_bytes();
let output = String::from_utf8_lossy(&output_bytes); let output = String::from_utf8_lossy(&output_bytes);
println!("\nTRANSFORMED OUTPUT (ResponsesAPI):"); println!("\nTRANSFORMED OUTPUT (ResponsesAPI):");
@ -570,13 +661,34 @@ mod tests {
// Assertions // Assertions
assert!(!output_bytes.is_empty(), "Should have output"); assert!(!output_bytes.is_empty(), "Should have output");
assert!(output.contains("response.created"), "Should have response.created"); assert!(
assert!(output.contains("response.in_progress"), "Should have response.in_progress"); output.contains("response.created"),
assert!(output.contains("response.output_item.added"), "Should have output_item.added"); "Should have response.created"
assert!(output.contains("response.output_text.delta"), "Should have text deltas"); );
assert!(output.contains("response.output_text.done"), "Should have text.done"); assert!(
assert!(output.contains("response.output_item.done"), "Should have output_item.done"); output.contains("response.in_progress"),
assert!(output.contains("response.completed"), "Should have response.completed"); "Should have response.in_progress"
);
assert!(
output.contains("response.output_item.added"),
"Should have output_item.added"
);
assert!(
output.contains("response.output_text.delta"),
"Should have text deltas"
);
assert!(
output.contains("response.output_text.done"),
"Should have text.done"
);
assert!(
output.contains("response.output_item.done"),
"Should have output_item.done"
);
assert!(
output.contains("response.completed"),
"Should have response.completed"
);
println!("\nVALIDATION SUMMARY:"); println!("\nVALIDATION SUMMARY:");
println!("{}", "-".repeat(80)); println!("{}", "-".repeat(80));
@ -616,7 +728,7 @@ mod tests {
buffer.add_transformed_event(transformed); buffer.add_transformed_event(transformed);
} }
let output_bytes = buffer.into_bytes(); let output_bytes = buffer.to_bytes();
let output = String::from_utf8_lossy(&output_bytes); let output = String::from_utf8_lossy(&output_bytes);
println!("\nTRANSFORMED OUTPUT (ResponsesAPI):"); println!("\nTRANSFORMED OUTPUT (ResponsesAPI):");
@ -624,24 +736,55 @@ mod tests {
println!("{}", output); println!("{}", output);
// Assertions // Assertions
assert!(output.contains("response.created"), "Should have response.created"); assert!(
assert!(output.contains("response.in_progress"), "Should have response.in_progress"); output.contains("response.created"),
assert!(output.contains("response.output_item.added"), "Should have output_item.added"); "Should have response.created"
assert!(output.contains("\"type\":\"function_call\""), "Should be function_call type"); );
assert!(output.contains("\"name\":\"get_weather\""), "Should have function name"); assert!(
assert!(output.contains("\"call_id\":\"call_mD5ggLKk3SMKGPFqFdcpKg6q\""), "Should have correct call_id"); output.contains("response.in_progress"),
"Should have response.in_progress"
);
assert!(
output.contains("response.output_item.added"),
"Should have output_item.added"
);
assert!(
output.contains("\"type\":\"function_call\""),
"Should be function_call type"
);
assert!(
output.contains("\"name\":\"get_weather\""),
"Should have function name"
);
assert!(
output.contains("\"call_id\":\"call_mD5ggLKk3SMKGPFqFdcpKg6q\""),
"Should have correct call_id"
);
let delta_count = output.matches("event: response.function_call_arguments.delta").count(); let delta_count = output
.matches("event: response.function_call_arguments.delta")
.count();
assert_eq!(delta_count, 4, "Should have 4 delta events"); assert_eq!(delta_count, 4, "Should have 4 delta events");
assert!(!output.contains("response.function_call_arguments.done"), "Should NOT have arguments.done"); assert!(
assert!(!output.contains("response.output_item.done"), "Should NOT have output_item.done"); !output.contains("response.function_call_arguments.done"),
assert!(!output.contains("response.completed"), "Should NOT have response.completed"); "Should NOT have arguments.done"
);
assert!(
!output.contains("response.output_item.done"),
"Should NOT have output_item.done"
);
assert!(
!output.contains("response.completed"),
"Should NOT have response.completed"
);
println!("\nVALIDATION SUMMARY:"); println!("\nVALIDATION SUMMARY:");
println!("{}", "-".repeat(80)); println!("{}", "-".repeat(80));
println!("✓ Lifecycle events: response.created, response.in_progress"); println!("✓ Lifecycle events: response.created, response.in_progress");
println!("✓ Function call metadata: name='get_weather', call_id='call_mD5ggLKk3SMKGPFqFdcpKg6q'"); println!(
"✓ Function call metadata: name='get_weather', call_id='call_mD5ggLKk3SMKGPFqFdcpKg6q'"
);
println!("✓ Incremental deltas: 4 events (1 initial + 3 argument chunks)"); println!("✓ Incremental deltas: 4 events (1 initial + 3 argument chunks)");
println!("✓ NO completion events (partial stream, no [DONE])"); println!("✓ NO completion events (partial stream, no [DONE])");
println!("✓ Arguments accumulated: '{{\"location\":\"'\n"); println!("✓ Arguments accumulated: '{{\"location\":\"'\n");

View file

@ -1,9 +1,9 @@
use crate::providers::streaming_response::ProviderStreamResponse;
use crate::providers::streaming_response::ProviderStreamResponseType;
use crate::apis::streaming_shapes::chat_completions_streaming_buffer::OpenAIChatCompletionsStreamBuffer;
use crate::apis::streaming_shapes::anthropic_streaming_buffer::AnthropicMessagesStreamBuffer; use crate::apis::streaming_shapes::anthropic_streaming_buffer::AnthropicMessagesStreamBuffer;
use crate::apis::streaming_shapes::chat_completions_streaming_buffer::OpenAIChatCompletionsStreamBuffer;
use crate::apis::streaming_shapes::passthrough_streaming_buffer::PassthroughStreamBuffer; use crate::apis::streaming_shapes::passthrough_streaming_buffer::PassthroughStreamBuffer;
use crate::apis::streaming_shapes::responses_api_streaming_buffer::ResponsesAPIStreamBuffer; use crate::apis::streaming_shapes::responses_api_streaming_buffer::ResponsesAPIStreamBuffer;
use crate::providers::streaming_response::ProviderStreamResponse;
use crate::providers::streaming_response::ProviderStreamResponseType;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::error::Error; use std::error::Error;
use std::fmt; use std::fmt;
@ -37,7 +37,7 @@ pub trait SseStreamBufferTrait: Send + Sync {
/// ///
/// # Returns /// # Returns
/// Bytes ready for wire transmission (may be empty if no events were accumulated) /// Bytes ready for wire transmission (may be empty if no events were accumulated)
fn into_bytes(&mut self) -> Vec<u8>; fn to_bytes(&mut self) -> Vec<u8>;
} }
/// Unified SSE Stream Buffer enum that provides a zero-cost abstraction /// Unified SSE Stream Buffer enum that provides a zero-cost abstraction
@ -45,7 +45,7 @@ pub enum SseStreamBuffer {
Passthrough(PassthroughStreamBuffer), Passthrough(PassthroughStreamBuffer),
OpenAIChatCompletions(OpenAIChatCompletionsStreamBuffer), OpenAIChatCompletions(OpenAIChatCompletionsStreamBuffer),
AnthropicMessages(AnthropicMessagesStreamBuffer), AnthropicMessages(AnthropicMessagesStreamBuffer),
OpenAIResponses(ResponsesAPIStreamBuffer), OpenAIResponses(Box<ResponsesAPIStreamBuffer>),
} }
impl SseStreamBufferTrait for SseStreamBuffer { impl SseStreamBufferTrait for SseStreamBuffer {
@ -58,12 +58,12 @@ impl SseStreamBufferTrait for SseStreamBuffer {
} }
} }
fn into_bytes(&mut self) -> Vec<u8> { fn to_bytes(&mut self) -> Vec<u8> {
match self { match self {
Self::Passthrough(buffer) => buffer.into_bytes(), Self::Passthrough(buffer) => buffer.to_bytes(),
Self::OpenAIChatCompletions(buffer) => buffer.into_bytes(), Self::OpenAIChatCompletions(buffer) => buffer.to_bytes(),
Self::AnthropicMessages(buffer) => buffer.into_bytes(), Self::AnthropicMessages(buffer) => buffer.to_bytes(),
Self::OpenAIResponses(buffer) => buffer.into_bytes(), Self::OpenAIResponses(buffer) => buffer.to_bytes(),
} }
} }
} }
@ -99,7 +99,7 @@ impl SseEvent {
let sse_string: String = response.clone().into(); let sse_string: String = response.clone().into();
SseEvent { SseEvent {
data: None, // Data is embedded in sse_transformed_lines data: None, // Data is embedded in sse_transformed_lines
event: None, // Event type is embedded in sse_transformed_lines event: None, // Event type is embedded in sse_transformed_lines
raw_line: sse_string.clone(), raw_line: sse_string.clone(),
sse_transformed_lines: sse_string, sse_transformed_lines: sse_string,
@ -149,10 +149,8 @@ impl FromStr for SseEvent {
}); });
} }
if trimmed_line.starts_with("data: ") { if let Some(stripped) = trimmed_line.strip_prefix("data: ") {
let data: String = trimmed_line[6..].to_string(); // Remove "data: " prefix let data: String = stripped.to_string();
// Allow empty data content after "data: " prefix
// This handles cases like "data: " followed by newline
if data.trim().is_empty() { if data.trim().is_empty() {
return Err(SseParseError { return Err(SseParseError {
message: "Empty data field after 'data: ' prefix".to_string(), message: "Empty data field after 'data: ' prefix".to_string(),
@ -166,8 +164,8 @@ impl FromStr for SseEvent {
sse_transformed_lines: line.to_string(), sse_transformed_lines: line.to_string(),
provider_stream_response: None, provider_stream_response: None,
}) })
} else if trimmed_line.starts_with("event: ") { } else if let Some(stripped) = trimmed_line.strip_prefix("event: ") {
let event_type = trimmed_line[7..].to_string(); let event_type = stripped.to_string();
if event_type.is_empty() { if event_type.is_empty() {
return Err(SseParseError { return Err(SseParseError {
message: "Empty event field is not a valid SSE event".to_string(), message: "Empty event field is not a valid SSE event".to_string(),
@ -183,7 +181,10 @@ impl FromStr for SseEvent {
}) })
} else { } else {
Err(SseParseError { Err(SseParseError {
message: format!("Line does not start with 'data: ' or 'event: ': {}", trimmed_line), message: format!(
"Line does not start with 'data: ' or 'event: ': {}",
trimmed_line
),
}) })
} }
} }
@ -196,16 +197,16 @@ impl fmt::Display for SseEvent {
} }
// Into implementation to convert SseEvent to bytes for response buffer // Into implementation to convert SseEvent to bytes for response buffer
impl Into<Vec<u8>> for SseEvent { impl From<SseEvent> for Vec<u8> {
fn into(self) -> Vec<u8> { fn from(val: SseEvent) -> Self {
// For generated events (like ResponsesAPI), sse_transformed_lines already includes trailing \n\n // For generated events (like ResponsesAPI), sse_transformed_lines already includes trailing \n\n
// For parsed events (like passthrough), we need to add the \n\n separator // For parsed events (like passthrough), we need to add the \n\n separator
if self.sse_transformed_lines.ends_with("\n\n") { if val.sse_transformed_lines.ends_with("\n\n") {
// Already properly formatted with trailing newlines // Already properly formatted with trailing newlines
self.sse_transformed_lines.into_bytes() val.sse_transformed_lines.into_bytes()
} else { } else {
// Add SSE event separator // Add SSE event separator
format!("{}\n\n", self.sse_transformed_lines).into_bytes() format!("{}\n\n", val.sse_transformed_lines).into_bytes()
} }
} }
} }

View file

@ -10,6 +10,12 @@ pub struct SseChunkProcessor {
incomplete_event_buffer: Vec<u8>, incomplete_event_buffer: Vec<u8>,
} }
impl Default for SseChunkProcessor {
fn default() -> Self {
Self::new()
}
}
impl SseChunkProcessor { impl SseChunkProcessor {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
@ -93,8 +99,8 @@ impl SseChunkProcessor {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
use crate::apis::openai::OpenAIApi; use crate::apis::openai::OpenAIApi;
use crate::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
#[test] #[test]
fn test_complete_events_process_immediately() { fn test_complete_events_process_immediately() {
@ -104,7 +110,9 @@ mod tests {
let chunk1 = b"data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n"; let chunk1 = b"data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n";
let events = processor.process_chunk(chunk1, &client_api, &upstream_api).unwrap(); let events = processor
.process_chunk(chunk1, &client_api, &upstream_api)
.unwrap();
assert_eq!(events.len(), 1); assert_eq!(events.len(), 1);
assert!(!processor.has_buffered_data()); assert!(!processor.has_buffered_data());
@ -119,18 +127,28 @@ mod tests {
// First chunk with incomplete JSON // First chunk with incomplete JSON
let chunk1 = b"data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chu"; let chunk1 = b"data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chu";
let events1 = processor.process_chunk(chunk1, &client_api, &upstream_api).unwrap(); let events1 = processor
.process_chunk(chunk1, &client_api, &upstream_api)
.unwrap();
assert_eq!(events1.len(), 0, "Incomplete event should not be processed"); assert_eq!(events1.len(), 0, "Incomplete event should not be processed");
assert!(processor.has_buffered_data(), "Incomplete data should be buffered"); assert!(
processor.has_buffered_data(),
"Incomplete data should be buffered"
);
// Second chunk completes the JSON // Second chunk completes the JSON
let chunk2 = b"nk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n"; let chunk2 = b"nk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n";
let events2 = processor.process_chunk(chunk2, &client_api, &upstream_api).unwrap(); let events2 = processor
.process_chunk(chunk2, &client_api, &upstream_api)
.unwrap();
assert_eq!(events2.len(), 1, "Complete event should be processed"); assert_eq!(events2.len(), 1, "Complete event should be processed");
assert!(!processor.has_buffered_data(), "Buffer should be cleared after completion"); assert!(
!processor.has_buffered_data(),
"Buffer should be cleared after completion"
);
} }
#[test] #[test]
@ -142,10 +160,15 @@ mod tests {
// Chunk with 2 complete events and 1 incomplete // Chunk with 2 complete events and 1 incomplete
let chunk = b"data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"A\"},\"finish_reason\":null}]}\n\ndata: {\"id\":\"chatcmpl-124\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"B\"},\"finish_reason\":null}]}\n\ndata: {\"id\":\"chatcmpl-125\",\"object\":\"chat.completion.chu"; let chunk = b"data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"A\"},\"finish_reason\":null}]}\n\ndata: {\"id\":\"chatcmpl-124\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"B\"},\"finish_reason\":null}]}\n\ndata: {\"id\":\"chatcmpl-125\",\"object\":\"chat.completion.chu";
let events = processor.process_chunk(chunk, &client_api, &upstream_api).unwrap(); let events = processor
.process_chunk(chunk, &client_api, &upstream_api)
.unwrap();
assert_eq!(events.len(), 2, "Two complete events should be processed"); assert_eq!(events.len(), 2, "Two complete events should be processed");
assert!(processor.has_buffered_data(), "Incomplete third event should be buffered"); assert!(
processor.has_buffered_data(),
"Incomplete third event should be buffered"
);
} }
#[test] #[test]
@ -171,11 +194,23 @@ data: {"type":"content_block_stop","index":0}
Ok(events) => { Ok(events) => {
println!("Successfully processed {} events", events.len()); println!("Successfully processed {} events", events.len());
for (i, event) in events.iter().enumerate() { for (i, event) in events.iter().enumerate() {
println!("Event {}: event={:?}, has_data={}", i, event.event, event.data.is_some()); println!(
"Event {}: event={:?}, has_data={}",
i,
event.event,
event.data.is_some()
);
} }
// Should successfully process both events (signature_delta + content_block_stop) // Should successfully process both events (signature_delta + content_block_stop)
assert!(events.len() >= 2, "Should process at least 2 complete events (signature_delta + stop), got {}", events.len()); assert!(
assert!(!processor.has_buffered_data(), "Complete events should not be buffered"); events.len() >= 2,
"Should process at least 2 complete events (signature_delta + stop), got {}",
events.len()
);
assert!(
!processor.has_buffered_data(),
"Complete events should not be buffered"
);
} }
Err(e) => { Err(e) => {
panic!("Failed to process signature_delta chunk - this means SignatureDelta is not properly handled: {}", e); panic!("Failed to process signature_delta chunk - this means SignatureDelta is not properly handled: {}", e);
@ -194,12 +229,21 @@ data: {"type":"content_block_stop","index":0}
// Second event is valid and should be processed // Second event is valid and should be processed
let chunk = b"data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"unsupported_field_causing_validation_error\":true},\"finish_reason\":null}]}\n\ndata: {\"id\":\"chatcmpl-124\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n"; let chunk = b"data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"unsupported_field_causing_validation_error\":true},\"finish_reason\":null}]}\n\ndata: {\"id\":\"chatcmpl-124\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n";
let events = processor.process_chunk(chunk, &client_api, &upstream_api).unwrap(); let events = processor
.process_chunk(chunk, &client_api, &upstream_api)
.unwrap();
// Should skip the invalid event and process the valid one // Should skip the invalid event and process the valid one
// (If we were buffering all errors, we'd get 0 events and have buffered data) // (If we were buffering all errors, we'd get 0 events and have buffered data)
assert!(events.len() >= 1, "Should process at least the valid event, got {} events", events.len()); assert!(
assert!(!processor.has_buffered_data(), "Invalid (non-incomplete) events should not be buffered"); !events.is_empty(),
"Should process at least the valid event, got {} events",
events.len()
);
assert!(
!processor.has_buffered_data(),
"Invalid (non-incomplete) events should not be buffered"
);
} }
#[test] #[test]
@ -227,14 +271,27 @@ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text
match result { match result {
Ok(events) => { Ok(events) => {
println!("Processed {} events (unsupported event should be skipped)", events.len()); println!(
"Processed {} events (unsupported event should be skipped)",
events.len()
);
// Should process the 2 valid text_delta events and skip the unsupported one // Should process the 2 valid text_delta events and skip the unsupported one
// We expect at least 2 events (the valid ones), unsupported should be skipped // We expect at least 2 events (the valid ones), unsupported should be skipped
assert!(events.len() >= 2, "Should process at least 2 valid events, got {}", events.len()); assert!(
assert!(!processor.has_buffered_data(), "Unsupported events should be skipped, not buffered"); events.len() >= 2,
"Should process at least 2 valid events, got {}",
events.len()
);
assert!(
!processor.has_buffered_data(),
"Unsupported events should be skipped, not buffered"
);
} }
Err(e) => { Err(e) => {
panic!("Should not fail on unsupported delta type, should skip it: {}", e); panic!(
"Should not fail on unsupported delta type, should skip it: {}",
e
);
} }
} }
} }

View file

@ -135,7 +135,10 @@ impl SupportedAPIsFromClient {
ProviderId::AzureOpenAI => { ProviderId::AzureOpenAI => {
if request_path.starts_with("/v1/") { if request_path.starts_with("/v1/") {
let suffix = endpoint_suffix.trim_start_matches('/'); let suffix = endpoint_suffix.trim_start_matches('/');
build_endpoint("/openai/deployments", &format!("/{}/{}?api-version=2025-01-01-preview", model_id, suffix)) build_endpoint(
"/openai/deployments",
&format!("/{}/{}?api-version=2025-01-01-preview", model_id, suffix),
)
} else { } else {
build_endpoint("/v1", endpoint_suffix) build_endpoint("/v1", endpoint_suffix)
} }
@ -163,19 +166,21 @@ impl SupportedAPIsFromClient {
}; };
match self { match self {
SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages) => match provider_id { SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages) => {
ProviderId::Anthropic => build_endpoint("/v1", "/messages"), match provider_id {
ProviderId::AmazonBedrock => { ProviderId::Anthropic => build_endpoint("/v1", "/messages"),
if request_path.starts_with("/v1/") && !is_streaming { ProviderId::AmazonBedrock => {
build_endpoint("", &format!("/model/{}/converse", model_id)) if request_path.starts_with("/v1/") && !is_streaming {
} else if request_path.starts_with("/v1/") && is_streaming { build_endpoint("", &format!("/model/{}/converse", model_id))
build_endpoint("", &format!("/model/{}/converse-stream", model_id)) } else if request_path.starts_with("/v1/") && is_streaming {
} else { build_endpoint("", &format!("/model/{}/converse-stream", model_id))
build_endpoint("/v1", "/chat/completions") } else {
build_endpoint("/v1", "/chat/completions")
}
} }
_ => build_endpoint("/v1", "/chat/completions"),
} }
_ => build_endpoint("/v1", "/chat/completions"), }
},
SupportedAPIsFromClient::OpenAIResponsesAPI(_) => { SupportedAPIsFromClient::OpenAIResponsesAPI(_) => {
// For Responses API, check if provider supports it, otherwise translate to chat/completions // For Responses API, check if provider supports it, otherwise translate to chat/completions
match provider_id { match provider_id {
@ -193,7 +198,6 @@ impl SupportedAPIsFromClient {
} }
} }
impl SupportedUpstreamAPIs { impl SupportedUpstreamAPIs {
/// Create a SupportedUpstreamApi from an endpoint path /// Create a SupportedUpstreamApi from an endpoint path
pub fn from_endpoint(endpoint: &str) -> Option<Self> { pub fn from_endpoint(endpoint: &str) -> Option<Self> {
@ -216,17 +220,17 @@ impl SupportedUpstreamAPIs {
return Some(SupportedUpstreamAPIs::AmazonBedrockConverse(bedrock_api)) return Some(SupportedUpstreamAPIs::AmazonBedrockConverse(bedrock_api))
} }
AmazonBedrockApi::ConverseStream => { AmazonBedrockApi::ConverseStream => {
return Some(SupportedUpstreamAPIs::AmazonBedrockConverseStream(bedrock_api)) return Some(SupportedUpstreamAPIs::AmazonBedrockConverseStream(
bedrock_api,
))
} }
} }
} }
None None
} }
} }
/// Get all supported endpoint paths /// Get all supported endpoint paths
pub fn supported_endpoints() -> Vec<&'static str> { pub fn supported_endpoints() -> Vec<&'static str> {
let mut endpoints = Vec::new(); let mut endpoints = Vec::new();
@ -269,9 +273,9 @@ mod tests {
assert!(SupportedAPIsFromClient::from_endpoint("/v1/messages").is_some()); assert!(SupportedAPIsFromClient::from_endpoint("/v1/messages").is_some());
// Unsupported endpoints // Unsupported endpoints
assert!(!SupportedAPIsFromClient::from_endpoint("/v1/unknown").is_some()); assert!(SupportedAPIsFromClient::from_endpoint("/v1/unknown").is_none());
assert!(!SupportedAPIsFromClient::from_endpoint("/v2/chat").is_some()); assert!(SupportedAPIsFromClient::from_endpoint("/v2/chat").is_none());
assert!(!SupportedAPIsFromClient::from_endpoint("").is_some()); assert!(SupportedAPIsFromClient::from_endpoint("").is_none());
} }
#[test] #[test]

View file

@ -12,11 +12,9 @@ pub use aws_smithy_eventstream::frame::DecodedFrame;
pub use providers::id::ProviderId; pub use providers::id::ProviderId;
pub use providers::request::{ProviderRequest, ProviderRequestError, ProviderRequestType}; pub use providers::request::{ProviderRequest, ProviderRequestError, ProviderRequestType};
pub use providers::response::{ pub use providers::response::{
ProviderResponse, ProviderResponseType, TokenUsage, ProviderResponseError ProviderResponse, ProviderResponseError, ProviderResponseType, TokenUsage,
};
pub use providers::streaming_response::{
ProviderStreamResponse, ProviderStreamResponseType
}; };
pub use providers::streaming_response::{ProviderStreamResponse, ProviderStreamResponseType};
//TODO: Refactor such that commons doesn't depend on Hermes. For now this will clean up strings //TODO: Refactor such that commons doesn't depend on Hermes. For now this will clean up strings
pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions"; pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
@ -87,11 +85,17 @@ mod tests {
let done_event = streaming_iter.next(); let done_event = streaming_iter.next();
assert!(done_event.is_some(), "Should get [DONE] event"); assert!(done_event.is_some(), "Should get [DONE] event");
let done_event = done_event.unwrap(); let done_event = done_event.unwrap();
assert!(done_event.is_done(), "[DONE] event should be marked as done"); assert!(
done_event.is_done(),
"[DONE] event should be marked as done"
);
// After [DONE], iterator should return None // After [DONE], iterator should return None
let final_event = streaming_iter.next(); let final_event = streaming_iter.next();
assert!(final_event.is_none(), "Iterator should return None after [DONE]"); assert!(
final_event.is_none(),
"Iterator should return None after [DONE]"
);
} }
/// Test AWS Event Stream decoding for Bedrock ConverseStream responses. /// Test AWS Event Stream decoding for Bedrock ConverseStream responses.
@ -130,7 +134,7 @@ mod tests {
let mut content_chunks = Vec::new(); let mut content_chunks = Vec::new();
// Simulate chunked network arrivals - process as data comes in // Simulate chunked network arrivals - process as data comes in
let chunk_sizes = vec![50, 100, 75, 200, 150, 300, 500, 1000]; let chunk_sizes = [50, 100, 75, 200, 150, 300, 500, 1000];
let mut offset = 0; let mut offset = 0;
let mut chunk_num = 0; let mut chunk_num = 0;

View file

@ -59,10 +59,9 @@ impl ProviderId {
(ProviderId::Anthropic, SupportedAPIsFromClient::AnthropicMessagesAPI(_)) => { (ProviderId::Anthropic, SupportedAPIsFromClient::AnthropicMessagesAPI(_)) => {
SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages)
} }
( (ProviderId::Anthropic, SupportedAPIsFromClient::OpenAIChatCompletions(_)) => {
ProviderId::Anthropic, SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions)
SupportedAPIsFromClient::OpenAIChatCompletions(_), }
) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
// Anthropic doesn't support Responses API, fall back to chat completions // Anthropic doesn't support Responses API, fall back to chat completions
(ProviderId::Anthropic, SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => { (ProviderId::Anthropic, SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => {

View file

@ -10,6 +10,7 @@ use serde_json::Value;
use std::collections::HashMap; use std::collections::HashMap;
use std::error::Error; use std::error::Error;
use std::fmt; use std::fmt;
#[allow(clippy::large_enum_variant)]
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub enum ProviderRequestType { pub enum ProviderRequestType {
ChatCompletionsRequest(ChatCompletionsRequest), ChatCompletionsRequest(ChatCompletionsRequest),
@ -197,7 +198,9 @@ impl ProviderRequest for ProviderRequestType {
impl TryFrom<(&[u8], &SupportedAPIsFromClient)> for ProviderRequestType { impl TryFrom<(&[u8], &SupportedAPIsFromClient)> for ProviderRequestType {
type Error = std::io::Error; type Error = std::io::Error;
fn try_from((bytes, client_api): (&[u8], &SupportedAPIsFromClient)) -> Result<Self, Self::Error> { fn try_from(
(bytes, client_api): (&[u8], &SupportedAPIsFromClient),
) -> Result<Self, Self::Error> {
// Use SupportedApi to determine the appropriate request type // Use SupportedApi to determine the appropriate request type
match client_api { match client_api {
SupportedAPIsFromClient::OpenAIChatCompletions(_) => { SupportedAPIsFromClient::OpenAIChatCompletions(_) => {
@ -882,7 +885,7 @@ mod tests {
ProviderRequestType::BedrockConverse(bedrock_req) => { ProviderRequestType::BedrockConverse(bedrock_req) => {
assert_eq!(bedrock_req.model_id, "gpt-4o"); assert_eq!(bedrock_req.model_id, "gpt-4o");
// Bedrock receives the converted request through ChatCompletions // Bedrock receives the converted request through ChatCompletions
assert!(!bedrock_req.messages.is_none()); assert!(bedrock_req.messages.is_some());
} }
_ => panic!("Expected BedrockConverse variant"), _ => panic!("Expected BedrockConverse variant"),
} }
@ -913,7 +916,9 @@ mod tests {
assert!(result.is_err()); assert!(result.is_err());
let err = result.unwrap_err(); let err = result.unwrap_err();
assert!(err.message.contains("ResponsesAPI can only be used as a client API")); assert!(err
.message
.contains("ResponsesAPI can only be used as a client API"));
} }
#[test] #[test]
@ -953,7 +958,9 @@ mod tests {
assert!(result.is_err()); assert!(result.is_err());
let err = result.unwrap_err(); let err = result.unwrap_err();
assert!(err.message.contains("ResponsesAPI can only be used as a client API")); assert!(err
.message
.contains("ResponsesAPI can only be used as a client API"));
} }
#[test] #[test]
@ -1023,9 +1030,7 @@ mod tests {
role: MessagesRole::User, role: MessagesRole::User,
content: MessagesMessageContent::Single("Hello!".to_string()), content: MessagesMessageContent::Single("Hello!".to_string()),
}], }],
system: Some(MessagesSystemPrompt::Single( system: Some(MessagesSystemPrompt::Single("You are helpful".to_string())),
"You are helpful".to_string(),
)),
max_tokens: 100, max_tokens: 100,
container: None, container: None,
mcp_servers: None, mcp_servers: None,
@ -1046,14 +1051,8 @@ mod tests {
// Should have system message + user message // Should have system message + user message
assert_eq!(messages.len(), 2); assert_eq!(messages.len(), 2);
assert_eq!( assert_eq!(messages[0].role, crate::apis::openai::Role::System);
messages[0].role, assert_eq!(messages[1].role, crate::apis::openai::Role::User);
crate::apis::openai::Role::System
);
assert_eq!(
messages[1].role,
crate::apis::openai::Role::User
);
} }
#[test] #[test]
@ -1094,13 +1093,7 @@ mod tests {
// Should have system message (instructions) + user message (input) // Should have system message (instructions) + user message (input)
assert_eq!(messages.len(), 2); assert_eq!(messages.len(), 2);
assert_eq!( assert_eq!(messages[0].role, crate::apis::openai::Role::System);
messages[0].role, assert_eq!(messages[1].role, crate::apis::openai::Role::User);
crate::apis::openai::Role::System
);
assert_eq!(
messages[1].role,
crate::apis::openai::Role::User
);
} }
} }

View file

@ -1,7 +1,3 @@
use serde::Serialize;
use std::convert::TryFrom;
use std::error::Error;
use std::fmt;
use crate::apis::amazon_bedrock::ConverseResponse; use crate::apis::amazon_bedrock::ConverseResponse;
use crate::apis::anthropic::MessagesResponse; use crate::apis::anthropic::MessagesResponse;
use crate::apis::openai::ChatCompletionsResponse; use crate::apis::openai::ChatCompletionsResponse;
@ -9,14 +5,17 @@ use crate::apis::openai_responses::ResponsesAPIResponse;
use crate::clients::endpoints::SupportedAPIsFromClient; use crate::clients::endpoints::SupportedAPIsFromClient;
use crate::clients::endpoints::SupportedUpstreamAPIs; use crate::clients::endpoints::SupportedUpstreamAPIs;
use crate::providers::id::ProviderId; use crate::providers::id::ProviderId;
use serde::Serialize;
use std::convert::TryFrom;
use std::error::Error;
use std::fmt;
#[derive(Serialize, Debug, Clone)] #[derive(Serialize, Debug, Clone)]
#[serde(untagged)] #[serde(untagged)]
pub enum ProviderResponseType { pub enum ProviderResponseType {
ChatCompletionsResponse(ChatCompletionsResponse), ChatCompletionsResponse(ChatCompletionsResponse),
MessagesResponse(MessagesResponse), MessagesResponse(MessagesResponse),
ResponsesAPIResponse(ResponsesAPIResponse), ResponsesAPIResponse(Box<ResponsesAPIResponse>),
} }
/// Trait for token usage information /// Trait for token usage information
@ -42,7 +41,9 @@ impl ProviderResponse for ProviderResponseType {
match self { match self {
ProviderResponseType::ChatCompletionsResponse(resp) => resp.usage(), ProviderResponseType::ChatCompletionsResponse(resp) => resp.usage(),
ProviderResponseType::MessagesResponse(resp) => resp.usage(), ProviderResponseType::MessagesResponse(resp) => resp.usage(),
ProviderResponseType::ResponsesAPIResponse(resp) => resp.usage.as_ref().map(|u| u as &dyn TokenUsage), ProviderResponseType::ResponsesAPIResponse(resp) => {
resp.usage.as_ref().map(|u| u as &dyn TokenUsage)
}
} }
} }
@ -50,11 +51,13 @@ impl ProviderResponse for ProviderResponseType {
match self { match self {
ProviderResponseType::ChatCompletionsResponse(resp) => resp.extract_usage_counts(), ProviderResponseType::ChatCompletionsResponse(resp) => resp.extract_usage_counts(),
ProviderResponseType::MessagesResponse(resp) => resp.extract_usage_counts(), ProviderResponseType::MessagesResponse(resp) => resp.extract_usage_counts(),
ProviderResponseType::ResponsesAPIResponse(resp) => { ProviderResponseType::ResponsesAPIResponse(resp) => resp.usage.as_ref().map(|u| {
resp.usage.as_ref().map(|u| { (
(u.input_tokens as usize, u.output_tokens as usize, u.total_tokens as usize) u.input_tokens as usize,
}) u.output_tokens as usize,
} u.total_tokens as usize,
)
}),
} }
} }
} }
@ -156,40 +159,44 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClient, &ProviderId)> for ProviderRespons
) => { ) => {
let resp: ResponsesAPIResponse = ResponsesAPIResponse::try_from(bytes) let resp: ResponsesAPIResponse = ResponsesAPIResponse::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(ProviderResponseType::ResponsesAPIResponse(resp)) Ok(ProviderResponseType::ResponsesAPIResponse(Box::new(resp)))
} }
( (
SupportedUpstreamAPIs::OpenAIChatCompletions(_), SupportedUpstreamAPIs::OpenAIChatCompletions(_),
SupportedAPIsFromClient::OpenAIResponsesAPI(_), SupportedAPIsFromClient::OpenAIResponsesAPI(_),
) => { ) => {
let chat_completions_response: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes) let chat_completions_response: ChatCompletionsResponse =
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; ChatCompletionsResponse::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
// Transform to ResponsesAPI format using the transformer // Transform to ResponsesAPI format using the transformer
let responses_resp: ResponsesAPIResponse = chat_completions_response.try_into().map_err(|e| { let responses_resp: ResponsesAPIResponse =
std::io::Error::new( chat_completions_response.try_into().map_err(|e| {
std::io::ErrorKind::InvalidData, std::io::Error::new(
format!("Transformation error: {}", e), std::io::ErrorKind::InvalidData,
) format!("Transformation error: {}", e),
})?; )
Ok(ProviderResponseType::ResponsesAPIResponse(responses_resp)) })?;
Ok(ProviderResponseType::ResponsesAPIResponse(Box::new(
responses_resp,
)))
} }
( (
SupportedUpstreamAPIs::AnthropicMessagesAPI(_), SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
SupportedAPIsFromClient::OpenAIResponsesAPI(_), SupportedAPIsFromClient::OpenAIResponsesAPI(_),
) => { ) => {
//Chain transform: Anthropic Messages -> OpenAI ChatCompletions -> ResponsesAPI //Chain transform: Anthropic Messages -> OpenAI ChatCompletions -> ResponsesAPI
let anthropic_resp: MessagesResponse = serde_json::from_slice(bytes) let anthropic_resp: MessagesResponse = serde_json::from_slice(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
// Transform to ChatCompletions format using the transformer // Transform to ChatCompletions format using the transformer
let chat_resp: ChatCompletionsResponse = anthropic_resp.try_into().map_err(|e| { let chat_resp: ChatCompletionsResponse =
std::io::Error::new( anthropic_resp.try_into().map_err(|e| {
std::io::ErrorKind::InvalidData, std::io::Error::new(
format!("Transformation error: {}", e), std::io::ErrorKind::InvalidData,
) format!("Transformation error: {}", e),
})?; )
})?;
let response_api: ResponsesAPIResponse = chat_resp.try_into().map_err(|e| { let response_api: ResponsesAPIResponse = chat_resp.try_into().map_err(|e| {
std::io::Error::new( std::io::Error::new(
@ -197,7 +204,9 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClient, &ProviderId)> for ProviderRespons
format!("Transformation error: {}", e), format!("Transformation error: {}", e),
) )
})?; })?;
Ok(ProviderResponseType::ResponsesAPIResponse(response_api)) Ok(ProviderResponseType::ResponsesAPIResponse(Box::new(
response_api,
)))
} }
( (
SupportedUpstreamAPIs::AmazonBedrockConverse(_), SupportedUpstreamAPIs::AmazonBedrockConverse(_),
@ -219,10 +228,15 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClient, &ProviderId)> for ProviderRespons
let response_api: ResponsesAPIResponse = chat_resp.try_into().map_err(|e| { let response_api: ResponsesAPIResponse = chat_resp.try_into().map_err(|e| {
std::io::Error::new( std::io::Error::new(
std::io::ErrorKind::InvalidData, std::io::ErrorKind::InvalidData,
format!("ChatCompletions to ResponsesAPI transformation error: {}", e), format!(
"ChatCompletions to ResponsesAPI transformation error: {}",
e
),
) )
})?; })?;
Ok(ProviderResponseType::ResponsesAPIResponse(response_api)) Ok(ProviderResponseType::ResponsesAPIResponse(Box::new(
response_api,
)))
} }
_ => Err(std::io::Error::new( _ => Err(std::io::Error::new(
std::io::ErrorKind::InvalidData, std::io::ErrorKind::InvalidData,
@ -255,8 +269,8 @@ impl Error for ProviderResponseError {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::apis::openai::OpenAIApi;
use crate::apis::anthropic::AnthropicApi; use crate::apis::anthropic::AnthropicApi;
use crate::apis::openai::OpenAIApi;
use crate::clients::endpoints::SupportedAPIsFromClient; use crate::clients::endpoints::SupportedAPIsFromClient;
use crate::providers::id::ProviderId; use crate::providers::id::ProviderId;
use serde_json::json; use serde_json::json;

View file

@ -1,18 +1,17 @@
use serde::Serialize; use serde::Serialize;
use std::convert::TryFrom; use std::convert::TryFrom;
use crate::apis::amazon_bedrock::ConverseStreamEvent;
use crate::apis::anthropic::MessagesStreamEvent;
use crate::apis::openai::ChatCompletionsStreamResponse; use crate::apis::openai::ChatCompletionsStreamResponse;
use crate::apis::openai_responses::ResponsesAPIStreamEvent; use crate::apis::openai_responses::ResponsesAPIStreamEvent;
use crate::apis::streaming_shapes::sse::SseEvent; use crate::apis::streaming_shapes::sse::SseEvent;
use crate::apis::amazon_bedrock::ConverseStreamEvent;
use crate::apis::anthropic::MessagesStreamEvent;
use crate::apis::streaming_shapes::sse::SseStreamBuffer; use crate::apis::streaming_shapes::sse::SseStreamBuffer;
use crate::apis::streaming_shapes::{ use crate::apis::streaming_shapes::{
anthropic_streaming_buffer::AnthropicMessagesStreamBuffer, anthropic_streaming_buffer::AnthropicMessagesStreamBuffer,
chat_completions_streaming_buffer::OpenAIChatCompletionsStreamBuffer, chat_completions_streaming_buffer::OpenAIChatCompletionsStreamBuffer,
passthrough_streaming_buffer::PassthroughStreamBuffer, passthrough_streaming_buffer::PassthroughStreamBuffer,
responses_api_streaming_buffer::ResponsesAPIStreamBuffer, };
};
use crate::clients::endpoints::SupportedAPIsFromClient; use crate::clients::endpoints::SupportedAPIsFromClient;
use crate::clients::endpoints::SupportedUpstreamAPIs; use crate::clients::endpoints::SupportedUpstreamAPIs;
@ -28,9 +27,18 @@ pub fn needs_buffering(
) -> bool { ) -> bool {
match (client_api, upstream_api) { match (client_api, upstream_api) {
// Same APIs - no buffering needed // Same APIs - no buffering needed
(SupportedAPIsFromClient::OpenAIChatCompletions(_), SupportedUpstreamAPIs::OpenAIChatCompletions(_)) => false, (
(SupportedAPIsFromClient::AnthropicMessagesAPI(_), SupportedUpstreamAPIs::AnthropicMessagesAPI(_)) => false, SupportedAPIsFromClient::OpenAIChatCompletions(_),
(SupportedAPIsFromClient::OpenAIResponsesAPI(_), SupportedUpstreamAPIs::OpenAIResponsesAPI(_)) => false, SupportedUpstreamAPIs::OpenAIChatCompletions(_),
) => false,
(
SupportedAPIsFromClient::AnthropicMessagesAPI(_),
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
) => false,
(
SupportedAPIsFromClient::OpenAIResponsesAPI(_),
SupportedUpstreamAPIs::OpenAIResponsesAPI(_),
) => false,
// Different APIs - buffering needed // Different APIs - buffering needed
_ => true, _ => true,
@ -53,15 +61,12 @@ pub fn needs_buffering(
/// // Flush to wire /// // Flush to wire
/// let bytes = buffer.into_bytes(); /// let bytes = buffer.into_bytes();
/// ``` /// ```
impl TryFrom<(&SupportedAPIsFromClient, &SupportedUpstreamAPIs)> impl TryFrom<(&SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for SseStreamBuffer {
for SseStreamBuffer
{
type Error = Box<dyn std::error::Error + Send + Sync>; type Error = Box<dyn std::error::Error + Send + Sync>;
fn try_from( fn try_from(
(client_api, upstream_api): (&SupportedAPIsFromClient, &SupportedUpstreamAPIs), (client_api, upstream_api): (&SupportedAPIsFromClient, &SupportedUpstreamAPIs),
) -> Result<Self, Self::Error> { ) -> Result<Self, Self::Error> {
// If APIs match, use passthrough - no buffering/transformation needed // If APIs match, use passthrough - no buffering/transformation needed
if !needs_buffering(client_api, upstream_api) { if !needs_buffering(client_api, upstream_api) {
return Ok(SseStreamBuffer::Passthrough(PassthroughStreamBuffer::new())); return Ok(SseStreamBuffer::Passthrough(PassthroughStreamBuffer::new()));
@ -69,14 +74,14 @@ impl TryFrom<(&SupportedAPIsFromClient, &SupportedUpstreamAPIs)>
// APIs differ - use appropriate buffer for client API // APIs differ - use appropriate buffer for client API
match client_api { match client_api {
SupportedAPIsFromClient::OpenAIChatCompletions(_) => { SupportedAPIsFromClient::OpenAIChatCompletions(_) => Ok(
Ok(SseStreamBuffer::OpenAIChatCompletions(OpenAIChatCompletionsStreamBuffer::new())) SseStreamBuffer::OpenAIChatCompletions(OpenAIChatCompletionsStreamBuffer::new()),
} ),
SupportedAPIsFromClient::AnthropicMessagesAPI(_) => { SupportedAPIsFromClient::AnthropicMessagesAPI(_) => Ok(
Ok(SseStreamBuffer::AnthropicMessages(AnthropicMessagesStreamBuffer::new())) SseStreamBuffer::AnthropicMessages(AnthropicMessagesStreamBuffer::new()),
} ),
SupportedAPIsFromClient::OpenAIResponsesAPI(_) => { SupportedAPIsFromClient::OpenAIResponsesAPI(_) => {
Ok(SseStreamBuffer::OpenAIResponses(ResponsesAPIStreamBuffer::new())) Ok(SseStreamBuffer::OpenAIResponses(Box::default()))
} }
} }
} }
@ -88,11 +93,12 @@ impl TryFrom<(&SupportedAPIsFromClient, &SupportedUpstreamAPIs)>
#[derive(Serialize, Debug, Clone)] #[derive(Serialize, Debug, Clone)]
#[serde(untagged)] #[serde(untagged)]
#[allow(clippy::large_enum_variant)]
pub enum ProviderStreamResponseType { pub enum ProviderStreamResponseType {
ChatCompletionsStreamResponse(ChatCompletionsStreamResponse), ChatCompletionsStreamResponse(ChatCompletionsStreamResponse),
MessagesStreamEvent(MessagesStreamEvent), MessagesStreamEvent(MessagesStreamEvent),
ConverseStreamEvent(ConverseStreamEvent), ConverseStreamEvent(ConverseStreamEvent),
ResponseAPIStreamEvent(ResponsesAPIStreamEvent) ResponseAPIStreamEvent(Box<ResponsesAPIStreamEvent>),
} }
pub trait ProviderStreamResponse: Send + Sync { pub trait ProviderStreamResponse: Send + Sync {
@ -145,12 +151,11 @@ impl ProviderStreamResponse for ProviderStreamResponseType {
ProviderStreamResponseType::ResponseAPIStreamEvent(resp) => resp.event_type(), ProviderStreamResponseType::ResponseAPIStreamEvent(resp) => resp.event_type(),
} }
} }
} }
impl Into<String> for ProviderStreamResponseType { impl From<ProviderStreamResponseType> for String {
fn into(self) -> String { fn from(val: ProviderStreamResponseType) -> String {
match self { match val {
ProviderStreamResponseType::MessagesStreamEvent(event) => { ProviderStreamResponseType::MessagesStreamEvent(event) => {
// Use the Into<String> implementation for proper SSE formatting with event lines // Use the Into<String> implementation for proper SSE formatting with event lines
event.into() event.into()
@ -161,27 +166,36 @@ impl Into<String> for ProviderStreamResponseType {
} }
ProviderStreamResponseType::ResponseAPIStreamEvent(event) => { ProviderStreamResponseType::ResponseAPIStreamEvent(event) => {
// Use the Into<String> implementation for proper SSE formatting with event lines // Use the Into<String> implementation for proper SSE formatting with event lines
event.into() // Clone to work around Box<T> ownership
let cloned = (*event).clone();
cloned.into()
} }
ProviderStreamResponseType::ChatCompletionsStreamResponse(_) => { ProviderStreamResponseType::ChatCompletionsStreamResponse(_) => {
// For OpenAI, use simple data line format // For OpenAI, use simple data line format
let json = serde_json::to_string(&self).unwrap_or_default(); let json = serde_json::to_string(&val).unwrap_or_default();
format!("data: {}\n\n", json) format!("data: {}\n\n", json)
} }
} }
} }
} }
// Stream response transformation logic for client API compatibility // Stream response transformation logic for client API compatibility
impl TryFrom<(&[u8], &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for ProviderStreamResponseType { impl TryFrom<(&[u8], &SupportedAPIsFromClient, &SupportedUpstreamAPIs)>
for ProviderStreamResponseType
{
type Error = Box<dyn std::error::Error + Send + Sync>; type Error = Box<dyn std::error::Error + Send + Sync>;
fn try_from( fn try_from(
(bytes, client_api, upstream_api): (&[u8], &SupportedAPIsFromClient, &SupportedUpstreamAPIs), (bytes, client_api, upstream_api): (
&[u8],
&SupportedAPIsFromClient,
&SupportedUpstreamAPIs,
),
) -> Result<Self, Self::Error> { ) -> Result<Self, Self::Error> {
// Special case: Handle [DONE] marker for OpenAI -> Anthropic conversion // Special case: Handle [DONE] marker for OpenAI -> Anthropic conversion
if bytes == b"[DONE]" && matches!(client_api, SupportedAPIsFromClient::AnthropicMessagesAPI(_)) { if bytes == b"[DONE]"
&& matches!(client_api, SupportedAPIsFromClient::AnthropicMessagesAPI(_))
{
return Ok(ProviderStreamResponseType::MessagesStreamEvent( return Ok(ProviderStreamResponseType::MessagesStreamEvent(
crate::apis::anthropic::MessagesStreamEvent::MessageStop, crate::apis::anthropic::MessagesStreamEvent::MessageStop,
)); ));
@ -214,9 +228,9 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for Prov
) => { ) => {
let openai_resp: crate::apis::openai::ChatCompletionsStreamResponse = let openai_resp: crate::apis::openai::ChatCompletionsStreamResponse =
serde_json::from_slice(bytes)?; serde_json::from_slice(bytes)?;
let responses_resp = openai_resp.try_into()?; let responses_resp: ResponsesAPIStreamEvent = openai_resp.try_into()?;
Ok(ProviderStreamResponseType::ResponseAPIStreamEvent( Ok(ProviderStreamResponseType::ResponseAPIStreamEvent(
responses_resp, Box::new(responses_resp),
)) ))
} }
@ -267,10 +281,11 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for Prov
// Chain: Bedrock -> ChatCompletions -> ResponsesAPI // Chain: Bedrock -> ChatCompletions -> ResponsesAPI
let bedrock_resp: crate::apis::amazon_bedrock::ConverseStreamEvent = let bedrock_resp: crate::apis::amazon_bedrock::ConverseStreamEvent =
serde_json::from_slice(bytes)?; serde_json::from_slice(bytes)?;
let chat_resp: crate::apis::openai::ChatCompletionsStreamResponse = bedrock_resp.try_into()?; let chat_resp: crate::apis::openai::ChatCompletionsStreamResponse =
let responses_resp = chat_resp.try_into()?; bedrock_resp.try_into()?;
let responses_resp: ResponsesAPIStreamEvent = chat_resp.try_into()?;
Ok(ProviderStreamResponseType::ResponseAPIStreamEvent( Ok(ProviderStreamResponseType::ResponseAPIStreamEvent(
responses_resp, Box::new(responses_resp),
)) ))
} }
_ => Err(std::io::Error::new( _ => Err(std::io::Error::new(
@ -287,7 +302,11 @@ impl TryFrom<(SseEvent, &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for S
type Error = Box<dyn std::error::Error + Send + Sync>; type Error = Box<dyn std::error::Error + Send + Sync>;
fn try_from( fn try_from(
(sse_event, client_api, upstream_api): (SseEvent, &SupportedAPIsFromClient, &SupportedUpstreamAPIs), (sse_event, client_api, upstream_api): (
SseEvent,
&SupportedAPIsFromClient,
&SupportedUpstreamAPIs,
),
) -> Result<Self, Self::Error> { ) -> Result<Self, Self::Error> {
// Create a new transformed event based on the original // Create a new transformed event based on the original
let mut transformed_event = sse_event; let mut transformed_event = sse_event;
@ -296,7 +315,11 @@ impl TryFrom<(SseEvent, &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for S
if transformed_event.is_done() { if transformed_event.is_done() {
// For OpenAI client APIs (ChatCompletions and ResponsesAPI), keep [DONE] as-is // For OpenAI client APIs (ChatCompletions and ResponsesAPI), keep [DONE] as-is
// For Anthropic client API, it will be transformed via ProviderStreamResponseType // For Anthropic client API, it will be transformed via ProviderStreamResponseType
if matches!(client_api, SupportedAPIsFromClient::OpenAIChatCompletions(_) | SupportedAPIsFromClient::OpenAIResponsesAPI(_)) { if matches!(
client_api,
SupportedAPIsFromClient::OpenAIChatCompletions(_)
| SupportedAPIsFromClient::OpenAIResponsesAPI(_)
) {
// Keep the [DONE] marker as-is for OpenAI clients // Keep the [DONE] marker as-is for OpenAI clients
transformed_event.sse_transformed_lines = "data: [DONE]".to_string(); transformed_event.sse_transformed_lines = "data: [DONE]".to_string();
return Ok(transformed_event); return Ok(transformed_event);
@ -328,7 +351,7 @@ impl TryFrom<(SseEvent, &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for S
// OpenAI clients don't expect separate event: lines // OpenAI clients don't expect separate event: lines
// Suppress upstream Anthropic event-only lines // Suppress upstream Anthropic event-only lines
if transformed_event.is_event_only() && transformed_event.event.is_some() { if transformed_event.is_event_only() && transformed_event.event.is_some() {
transformed_event.sse_transformed_lines = format!("\n"); transformed_event.sse_transformed_lines = "\n".to_string();
} }
} }
_ => { _ => {
@ -345,7 +368,8 @@ impl TryFrom<(SseEvent, &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for S
( (
SupportedAPIsFromClient::AnthropicMessagesAPI(_), SupportedAPIsFromClient::AnthropicMessagesAPI(_),
SupportedUpstreamAPIs::AnthropicMessagesAPI(_), SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
) | ( )
| (
SupportedAPIsFromClient::OpenAIResponsesAPI(_), SupportedAPIsFromClient::OpenAIResponsesAPI(_),
SupportedUpstreamAPIs::OpenAIResponsesAPI(_), SupportedUpstreamAPIs::OpenAIResponsesAPI(_),
) => { ) => {
@ -415,7 +439,7 @@ impl
openai_event, openai_event,
)) ))
} }
( (
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
SupportedAPIsFromClient::OpenAIResponsesAPI(_), SupportedAPIsFromClient::OpenAIResponsesAPI(_),
) => { ) => {
@ -428,7 +452,7 @@ impl
openai_chat_completions_event.try_into()?; openai_chat_completions_event.try_into()?;
Ok(ProviderStreamResponseType::ResponseAPIStreamEvent( Ok(ProviderStreamResponseType::ResponseAPIStreamEvent(
openai_responses_api_event, Box::new(openai_responses_api_event),
)) ))
} }
_ => Err("Unsupported API combination for event-stream decoding".into()), _ => Err("Unsupported API combination for event-stream decoding".into()),
@ -445,11 +469,11 @@ impl
mod tests { mod tests {
use super::*; use super::*;
use crate::apis::streaming_shapes::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder; use crate::apis::streaming_shapes::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder;
use crate::clients::endpoints::SupportedAPIsFromClient;
use crate::apis::streaming_shapes::sse::SseStreamIter; use crate::apis::streaming_shapes::sse::SseStreamIter;
use crate::clients::endpoints::SupportedAPIsFromClient;
use serde_json::json; use serde_json::json;
#[test] #[test]
fn test_sse_event_parsing() { fn test_sse_event_parsing() {
// Test valid SSE data line // Test valid SSE data line
let line = "data: {\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n"; let line = "data: {\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n";
@ -792,7 +816,7 @@ mod tests {
// Simulate chunked network arrivals with realistic chunk sizes // Simulate chunked network arrivals with realistic chunk sizes
// Using varying chunk sizes to test partial frame handling // Using varying chunk sizes to test partial frame handling
let mut buffer = BytesMut::new(); let mut buffer = BytesMut::new();
let chunk_size_pattern = vec![500, 1000, 750, 1200, 800, 1500]; let chunk_size_pattern = [500, 1000, 750, 1200, 800, 1500];
let mut offset = 0; let mut offset = 0;
let mut total_frames = 0; let mut total_frames = 0;
let mut chunk_num = 0; let mut chunk_num = 0;
@ -837,7 +861,7 @@ mod tests {
); );
} }
#[test] #[test]
fn test_bedrock_decoded_frame_to_provider_response() { fn test_bedrock_decoded_frame_to_provider_response() {
test_bedrock_conversion(false); test_bedrock_conversion(false);
} }
@ -879,8 +903,9 @@ mod tests {
let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer); let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer);
let client_api = let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(
SupportedAPIsFromClient::AnthropicMessagesAPI(crate::apis::anthropic::AnthropicApi::Messages); crate::apis::anthropic::AnthropicApi::Messages,
);
let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream( let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream(
crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream, crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream,
); );
@ -966,8 +991,9 @@ mod tests {
let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer); let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer);
let client_api = let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(
SupportedAPIsFromClient::AnthropicMessagesAPI(crate::apis::anthropic::AnthropicApi::Messages); crate::apis::anthropic::AnthropicApi::Messages,
);
let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream( let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream(
crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream, crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream,
); );
@ -1051,7 +1077,6 @@ mod tests {
); );
} }
#[test] #[test]
fn test_sse_event_transformation_openai_to_anthropic_message_delta() { fn test_sse_event_transformation_openai_to_anthropic_message_delta() {
use crate::apis::anthropic::AnthropicApi; use crate::apis::anthropic::AnthropicApi;
@ -1079,8 +1104,8 @@ mod tests {
let sse_event = SseEvent { let sse_event = SseEvent {
data: Some(openai_stream_chunk.to_string()), data: Some(openai_stream_chunk.to_string()),
event: None, event: None,
raw_line: format!("data: {}", openai_stream_chunk.to_string()), raw_line: format!("data: {}", openai_stream_chunk),
sse_transformed_lines: format!("data: {}", openai_stream_chunk.to_string()), sse_transformed_lines: format!("data: {}", openai_stream_chunk),
provider_stream_response: None, provider_stream_response: None,
}; };
@ -1101,7 +1126,8 @@ mod tests {
// Verify the event was transformed to Anthropic format // Verify the event was transformed to Anthropic format
// This should contain message_delta with stop_reason and usage // This should contain message_delta with stop_reason and usage
assert!( assert!(
buffer.contains("event: message_delta") || buffer.contains("\"type\":\"message_delta\""), buffer.contains("event: message_delta")
|| buffer.contains("\"type\":\"message_delta\""),
"Should contain message_delta in transformed event" "Should contain message_delta in transformed event"
); );
@ -1134,8 +1160,8 @@ mod tests {
let sse_event = SseEvent { let sse_event = SseEvent {
data: Some(openai_stream_chunk.to_string()), data: Some(openai_stream_chunk.to_string()),
event: None, event: None,
raw_line: format!("data: {}", openai_stream_chunk.to_string()), raw_line: format!("data: {}", openai_stream_chunk),
sse_transformed_lines: format!("data: {}", openai_stream_chunk.to_string()), sse_transformed_lines: format!("data: {}", openai_stream_chunk),
provider_stream_response: None, provider_stream_response: None,
}; };
@ -1223,8 +1249,8 @@ mod tests {
let sse_event = SseEvent { let sse_event = SseEvent {
data: Some(anthropic_event.to_string()), data: Some(anthropic_event.to_string()),
event: None, event: None,
raw_line: format!("data: {}", anthropic_event.to_string()), raw_line: format!("data: {}", anthropic_event),
sse_transformed_lines: format!("data: {}", anthropic_event.to_string()), sse_transformed_lines: format!("data: {}", anthropic_event),
provider_stream_response: None, provider_stream_response: None,
}; };
@ -1314,8 +1340,8 @@ mod tests {
let sse_event = SseEvent { let sse_event = SseEvent {
data: Some(openai_stream_chunk.to_string()), data: Some(openai_stream_chunk.to_string()),
event: None, event: None,
raw_line: format!("data: {}", openai_stream_chunk.to_string()), raw_line: format!("data: {}", openai_stream_chunk),
sse_transformed_lines: format!("data: {}", openai_stream_chunk.to_string()), sse_transformed_lines: format!("data: {}", openai_stream_chunk),
provider_stream_response: None, provider_stream_response: None,
}; };

View file

@ -11,11 +11,11 @@ pub trait ExtractText {
/// Trait for utility functions on content collections /// Trait for utility functions on content collections
pub trait ContentUtils<T> { pub trait ContentUtils<T> {
fn extract_tool_calls(&self) -> Result<Option<Vec<ToolCall>>, TransformError>; fn extract_tool_calls(&self) -> Result<Option<Vec<ToolCall>>, TransformError>;
fn split_for_openai( fn split_for_openai(&self) -> Result<SplitForOpenAIResult, TransformError>;
&self,
) -> Result<(Vec<ContentPart>, Vec<ToolCall>, Vec<(String, String, bool)>), TransformError>;
} }
pub type SplitForOpenAIResult = (Vec<ContentPart>, Vec<ToolCall>, Vec<(String, String, bool)>);
/// Helper to create a current unix timestamp /// Helper to create a current unix timestamp
pub fn current_timestamp() -> u64 { pub fn current_timestamp() -> u64 {
SystemTime::now() SystemTime::now()

View file

@ -38,7 +38,7 @@ impl TryFrom<AnthropicMessagesRequest> for ChatCompletionsRequest {
} }
// Convert tools and tool choice // Convert tools and tool choice
let openai_tools = req.tools.map(|tools| convert_anthropic_tools(tools)); let openai_tools = req.tools.map(convert_anthropic_tools);
let (openai_tool_choice, parallel_tool_calls) = let (openai_tool_choice, parallel_tool_calls) =
convert_anthropic_tool_choice(req.tool_choice); convert_anthropic_tool_choice(req.tool_choice);
@ -218,18 +218,18 @@ impl TryFrom<MessagesMessage> for Vec<Message> {
} }
// Role Conversions // Role Conversions
impl Into<Role> for MessagesRole { impl From<MessagesRole> for Role {
fn into(self) -> Role { fn from(val: MessagesRole) -> Self {
match self { match val {
MessagesRole::User => Role::User, MessagesRole::User => Role::User,
MessagesRole::Assistant => Role::Assistant, MessagesRole::Assistant => Role::Assistant,
} }
} }
} }
impl Into<MessagesStopReason> for FinishReason { impl From<FinishReason> for MessagesStopReason {
fn into(self) -> MessagesStopReason { fn from(val: FinishReason) -> Self {
match self { match val {
FinishReason::Stop => MessagesStopReason::EndTurn, FinishReason::Stop => MessagesStopReason::EndTurn,
FinishReason::Length => MessagesStopReason::MaxTokens, FinishReason::Length => MessagesStopReason::MaxTokens,
FinishReason::ToolCalls => MessagesStopReason::ToolUse, FinishReason::ToolCalls => MessagesStopReason::ToolUse,
@ -239,11 +239,11 @@ impl Into<MessagesStopReason> for FinishReason {
} }
} }
impl Into<MessagesUsage> for Usage { impl From<Usage> for MessagesUsage {
fn into(self) -> MessagesUsage { fn from(val: Usage) -> Self {
MessagesUsage { MessagesUsage {
input_tokens: self.prompt_tokens, input_tokens: val.prompt_tokens,
output_tokens: self.completion_tokens, output_tokens: val.completion_tokens,
cache_creation_input_tokens: None, cache_creation_input_tokens: None,
cache_read_input_tokens: None, cache_read_input_tokens: None,
} }
@ -251,9 +251,9 @@ impl Into<MessagesUsage> for Usage {
} }
// System Prompt Conversions // System Prompt Conversions
impl Into<Message> for MessagesSystemPrompt { impl From<MessagesSystemPrompt> for Message {
fn into(self) -> Message { fn from(val: MessagesSystemPrompt) -> Self {
let system_content = match self { let system_content = match val {
MessagesSystemPrompt::Single(text) => MessageContent::Text(text), MessagesSystemPrompt::Single(text) => MessageContent::Text(text),
MessagesSystemPrompt::Blocks(blocks) => MessageContent::Text(blocks.extract_text()), MessagesSystemPrompt::Blocks(blocks) => MessageContent::Text(blocks.extract_text()),
}; };
@ -384,12 +384,8 @@ impl TryFrom<MessagesMessage> for BedrockMessage {
ToolResultContent::Blocks(blocks) => { ToolResultContent::Blocks(blocks) => {
let mut result_blocks = Vec::new(); let mut result_blocks = Vec::new();
for result_block in blocks { for result_block in blocks {
match result_block { if let crate::apis::anthropic::MessagesContentBlock::Text { text, .. } = result_block {
crate::apis::anthropic::MessagesContentBlock::Text { text, .. } => { result_blocks.push(ToolResultContentBlock::Text { text });
result_blocks.push(ToolResultContentBlock::Text { text });
}
// For now, skip other content types in tool results
_ => {}
} }
} }
result_blocks result_blocks

View file

@ -14,7 +14,8 @@ use crate::apis::openai::{
}; };
use crate::apis::openai_responses::{ use crate::apis::openai_responses::{
ResponsesAPIRequest, InputContent, InputItem, InputParam, MessageRole, Modality, ReasoningEffort, Tool as ResponsesTool, ToolChoice as ResponsesToolChoice InputContent, InputItem, InputParam, MessageRole, Modality, ReasoningEffort,
ResponsesAPIRequest, Tool as ResponsesTool, ToolChoice as ResponsesToolChoice,
}; };
use crate::clients::TransformError; use crate::clients::TransformError;
use crate::transforms::lib::ExtractText; use crate::transforms::lib::ExtractText;
@ -27,9 +28,9 @@ type AnthropicMessagesRequest = MessagesRequest;
// MAIN REQUEST TRANSFORMATIONS // MAIN REQUEST TRANSFORMATIONS
// ============================================================================ // ============================================================================
impl Into<MessagesSystemPrompt> for Message { impl From<Message> for MessagesSystemPrompt {
fn into(self) -> MessagesSystemPrompt { fn from(val: Message) -> Self {
let system_text = match self.content { let system_text = match val.content {
MessageContent::Text(text) => text, MessageContent::Text(text) => text,
MessageContent::Parts(parts) => parts.extract_text(), MessageContent::Parts(parts) => parts.extract_text(),
}; };
@ -163,7 +164,7 @@ impl TryFrom<Message> for BedrockMessage {
let has_tool_calls = message let has_tool_calls = message
.tool_calls .tool_calls
.as_ref() .as_ref()
.map_or(false, |calls| !calls.is_empty()); .is_some_and(|calls| !calls.is_empty());
// Add text content if it's non-empty, or if we have no tool calls (to avoid empty content) // Add text content if it's non-empty, or if we have no tool calls (to avoid empty content)
if !text_content.is_empty() { if !text_content.is_empty() {
@ -252,7 +253,6 @@ impl TryFrom<ResponsesAPIRequest> for ChatCompletionsRequest {
type Error = TransformError; type Error = TransformError;
fn try_from(req: ResponsesAPIRequest) -> Result<Self, Self::Error> { fn try_from(req: ResponsesAPIRequest) -> Result<Self, Self::Error> {
// Convert input to messages // Convert input to messages
let messages = match req.input { let messages = match req.input {
InputParam::Text(text) => { InputParam::Text(text) => {
@ -282,50 +282,27 @@ impl TryFrom<ResponsesAPIRequest> for ChatCompletionsRequest {
// Convert each input item // Convert each input item
for item in items { for item in items {
match item { if let InputItem::Message(input_msg) = item {
InputItem::Message(input_msg) => { let role = match input_msg.role {
let role = match input_msg.role { MessageRole::User => Role::User,
MessageRole::User => Role::User, MessageRole::Assistant => Role::Assistant,
MessageRole::Assistant => Role::Assistant, MessageRole::System => Role::System,
MessageRole::System => Role::System, MessageRole::Developer => Role::System, // Map developer to system
MessageRole::Developer => Role::System, // Map developer to system };
};
// Convert content based on MessageContent type // Convert content based on MessageContent type
let content = match &input_msg.content { let content = match &input_msg.content {
crate::apis::openai_responses::MessageContent::Text(text) => { crate::apis::openai_responses::MessageContent::Text(text) => {
// Simple text content // Simple text content
MessageContent::Text(text.clone()) MessageContent::Text(text.clone())
} }
crate::apis::openai_responses::MessageContent::Items(content_items) => { crate::apis::openai_responses::MessageContent::Items(content_items) => {
// Check if it's a single text item (can use simple text format) // Check if it's a single text item (can use simple text format)
if content_items.len() == 1 { if content_items.len() == 1 {
if let InputContent::InputText { text } = &content_items[0] { if let InputContent::InputText { text } = &content_items[0] {
MessageContent::Text(text.clone()) MessageContent::Text(text.clone())
} else {
// Single non-text item - use parts format
MessageContent::Parts(
content_items.iter()
.filter_map(|c| match c {
InputContent::InputText { text } => {
Some(crate::apis::openai::ContentPart::Text { text: text.clone() })
}
InputContent::InputImage { image_url, .. } => {
Some(crate::apis::openai::ContentPart::ImageUrl {
image_url: crate::apis::openai::ImageUrl {
url: image_url.clone(),
detail: None,
}
})
}
InputContent::InputFile { .. } => None, // Skip files for now
InputContent::InputAudio { .. } => None, // Skip audio for now
})
.collect()
)
}
} else { } else {
// Multiple content items - convert to parts // Single non-text item - use parts format
MessageContent::Parts( MessageContent::Parts(
content_items.iter() content_items.iter()
.filter_map(|c| match c { .filter_map(|c| match c {
@ -346,20 +323,41 @@ impl TryFrom<ResponsesAPIRequest> for ChatCompletionsRequest {
.collect() .collect()
) )
} }
} else {
// Multiple content items - convert to parts
MessageContent::Parts(
content_items
.iter()
.filter_map(|c| match c {
InputContent::InputText { text } => {
Some(crate::apis::openai::ContentPart::Text {
text: text.clone(),
})
}
InputContent::InputImage { image_url, .. } => Some(
crate::apis::openai::ContentPart::ImageUrl {
image_url: crate::apis::openai::ImageUrl {
url: image_url.clone(),
detail: None,
},
},
),
InputContent::InputFile { .. } => None, // Skip files for now
InputContent::InputAudio { .. } => None, // Skip audio for now
})
.collect(),
)
} }
}; }
};
converted_messages.push(Message { converted_messages.push(Message {
role, role,
content, content,
name: None, name: None,
tool_call_id: None, tool_call_id: None,
tool_calls: None, tool_calls: None,
}); });
}
// Skip non-message items (references, outputs) for now
// These would need special handling in chat completions format
_ => {}
} }
} }
@ -474,7 +472,7 @@ impl TryFrom<ChatCompletionsRequest> for AnthropicMessagesRequest {
} }
// Convert tools and tool choice // Convert tools and tool choice
let anthropic_tools = req.tools.map(|tools| convert_openai_tools(tools)); let anthropic_tools = req.tools.map(convert_openai_tools);
let anthropic_tool_choice = let anthropic_tool_choice =
convert_openai_tool_choice(req.tool_choice, req.parallel_tool_calls); convert_openai_tool_choice(req.tool_choice, req.parallel_tool_calls);

View file

@ -13,18 +13,14 @@ use crate::apis::openai_responses::{
pub fn convert_responses_output_to_input_items(output: &OutputItem) -> Option<InputItem> { pub fn convert_responses_output_to_input_items(output: &OutputItem) -> Option<InputItem> {
match output { match output {
// Convert output messages to input messages // Convert output messages to input messages
OutputItem::Message { OutputItem::Message { role, content, .. } => {
role, content, ..
} => {
let input_content: Vec<InputContent> = content let input_content: Vec<InputContent> = content
.iter() .iter()
.filter_map(|c| match c { .filter_map(|c| match c {
OutputContent::OutputText { text, .. } => Some(InputContent::InputText { OutputContent::OutputText { text, .. } => {
text: text.clone(), Some(InputContent::InputText { text: text.clone() })
}), }
OutputContent::OutputAudio { OutputContent::OutputAudio { data, .. } => Some(InputContent::InputAudio {
data, ..
} => Some(InputContent::InputAudio {
data: data.clone(), data: data.clone(),
format: None, // Format not preserved in output format: None, // Format not preserved in output
}), }),
@ -84,7 +80,7 @@ pub fn outputs_to_inputs(outputs: &[OutputItem]) -> Vec<InputItem> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::apis::openai_responses::{OutputItemStatus}; use crate::apis::openai_responses::OutputItemStatus;
#[test] #[test]
fn test_output_message_to_input() { fn test_output_message_to_input() {
@ -135,14 +131,12 @@ mod tests {
InputItem::Message(msg) => { InputItem::Message(msg) => {
assert!(matches!(msg.role, MessageRole::Assistant)); assert!(matches!(msg.role, MessageRole::Assistant));
match &msg.content { match &msg.content {
MessageContent::Items(items) => { MessageContent::Items(items) => match &items[0] {
match &items[0] { InputContent::InputText { text } => {
InputContent::InputText { text } => { assert!(text.contains("get_weather"));
assert!(text.contains("get_weather"));
}
_ => panic!("Expected InputText"),
} }
} _ => panic!("Expected InputText"),
},
_ => panic!("Expected MessageContent::Items"), _ => panic!("Expected MessageContent::Items"),
} }
} }

View file

@ -1,7 +1,6 @@
use crate::apis::amazon_bedrock::{ConverseOutput, ConverseResponse, StopReason}; use crate::apis::amazon_bedrock::{ConverseOutput, ConverseResponse, StopReason};
use crate::apis::anthropic::{ use crate::apis::anthropic::{
MessagesContentBlock, MessagesResponse, MessagesContentBlock, MessagesResponse, MessagesRole, MessagesStopReason, MessagesUsage,
MessagesRole, MessagesStopReason, MessagesUsage,
}; };
use crate::apis::openai::ChatCompletionsResponse; use crate::apis::openai::ChatCompletionsResponse;
use crate::clients::TransformError; use crate::clients::TransformError;
@ -115,7 +114,6 @@ impl TryFrom<ConverseResponse> for MessagesResponse {
} }
} }
/// Convert Bedrock Message to Anthropic content blocks /// Convert Bedrock Message to Anthropic content blocks
/// ///
/// This function handles the conversion between Amazon Bedrock Converse API format /// This function handles the conversion between Amazon Bedrock Converse API format

View file

@ -1,9 +1,5 @@
use crate::apis::amazon_bedrock::{ use crate::apis::amazon_bedrock::{ConverseOutput, ConverseResponse, StopReason};
ConverseOutput, ConverseResponse, StopReason, use crate::apis::anthropic::{MessagesContentBlock, MessagesResponse, MessagesUsage};
};
use crate::apis::anthropic::{
MessagesContentBlock, MessagesResponse, MessagesUsage,
};
use crate::apis::openai::{ use crate::apis::openai::{
ChatCompletionsResponse, Choice, FinishReason, MessageContent, ResponseMessage, Role, Usage, ChatCompletionsResponse, Choice, FinishReason, MessageContent, ResponseMessage, Role, Usage,
}; };
@ -16,12 +12,12 @@ use crate::transforms::lib::*;
// ============================================================================ // ============================================================================
// Usage Conversions // Usage Conversions
impl Into<Usage> for MessagesUsage { impl From<MessagesUsage> for Usage {
fn into(self) -> Usage { fn from(val: MessagesUsage) -> Self {
Usage { Usage {
prompt_tokens: self.input_tokens, prompt_tokens: val.input_tokens,
completion_tokens: self.output_tokens, completion_tokens: val.output_tokens,
total_tokens: self.input_tokens + self.output_tokens, total_tokens: val.input_tokens + val.output_tokens,
prompt_tokens_details: None, prompt_tokens_details: None,
completion_tokens_details: None, completion_tokens_details: None,
} }
@ -203,7 +199,6 @@ impl TryFrom<ChatCompletionsResponse> for ResponsesAPIResponse {
} }
} }
impl TryFrom<MessagesResponse> for ChatCompletionsResponse { impl TryFrom<MessagesResponse> for ChatCompletionsResponse {
type Error = TransformError; type Error = TransformError;
@ -415,7 +410,6 @@ fn convert_anthropic_content_to_openai(
Ok(MessageContent::Text(text_parts.join("\n"))) Ok(MessageContent::Text(text_parts.join("\n")))
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@ -994,8 +988,15 @@ mod tests {
let responses_api: ResponsesAPIResponse = chat_response.try_into().unwrap(); let responses_api: ResponsesAPIResponse = chat_response.try_into().unwrap();
// Response ID should be generated with resp_ prefix // Response ID should be generated with resp_ prefix
assert!(responses_api.id.starts_with("resp_"), "Response ID should start with 'resp_'"); assert!(
assert_eq!(responses_api.id.len(), 37, "Response ID should be resp_ + 32 char UUID"); responses_api.id.starts_with("resp_"),
"Response ID should start with 'resp_'"
);
assert_eq!(
responses_api.id.len(),
37,
"Response ID should be resp_ + 32 char UUID"
);
assert_eq!(responses_api.object, "response"); assert_eq!(responses_api.object, "response");
assert_eq!(responses_api.model, "gpt-4"); assert_eq!(responses_api.model, "gpt-4");
@ -1008,11 +1009,7 @@ mod tests {
// Check output items // Check output items
assert_eq!(responses_api.output.len(), 1); assert_eq!(responses_api.output.len(), 1);
match &responses_api.output[0] { match &responses_api.output[0] {
OutputItem::Message { OutputItem::Message { role, content, .. } => {
role,
content,
..
} => {
assert_eq!(role, "assistant"); assert_eq!(role, "assistant");
assert_eq!(content.len(), 1); assert_eq!(content.len(), 1);
match &content[0] { match &content[0] {
@ -1163,6 +1160,9 @@ mod tests {
} }
// Verify status is Completed for tool_calls finish reason // Verify status is Completed for tool_calls finish reason
assert!(matches!(responses_api.status, crate::apis::openai_responses::ResponseStatus::Completed)); assert!(matches!(
responses_api.status,
crate::apis::openai_responses::ResponseStatus::Completed
));
} }
} }

View file

@ -1,12 +1,9 @@
use crate::apis::amazon_bedrock::{ use crate::apis::amazon_bedrock::{ContentBlockDelta, ConverseStreamEvent};
ContentBlockDelta, ConverseStreamEvent,
};
use crate::apis::anthropic::{ use crate::apis::anthropic::{
MessagesContentBlock, MessagesContentDelta, MessagesMessageDelta, MessagesContentBlock, MessagesContentDelta, MessagesMessageDelta, MessagesRole,
MessagesRole, MessagesStopReason, MessagesStreamEvent, MessagesStreamMessage, MessagesUsage, MessagesStopReason, MessagesStreamEvent, MessagesStreamMessage, MessagesUsage,
};
use crate::apis::openai::{ ChatCompletionsStreamResponse, ToolCallDelta,
}; };
use crate::apis::openai::{ChatCompletionsStreamResponse, ToolCallDelta};
use crate::clients::TransformError; use crate::clients::TransformError;
use serde_json::Value; use serde_json::Value;
@ -86,10 +83,10 @@ impl TryFrom<ChatCompletionsStreamResponse> for MessagesStreamEvent {
} }
} }
impl Into<String> for MessagesStreamEvent { impl From<MessagesStreamEvent> for String {
fn into(self) -> String { fn from(val: MessagesStreamEvent) -> Self {
let transformed_json = serde_json::to_string(&self).unwrap_or_default(); let transformed_json = serde_json::to_string(&val).unwrap_or_default();
let event_type = match &self { let event_type = match &val {
MessagesStreamEvent::MessageStart { .. } => "message_start", MessagesStreamEvent::MessageStart { .. } => "message_start",
MessagesStreamEvent::ContentBlockStart { .. } => "content_block_start", MessagesStreamEvent::ContentBlockStart { .. } => "content_block_start",
MessagesStreamEvent::ContentBlockDelta { .. } => "content_block_delta", MessagesStreamEvent::ContentBlockDelta { .. } => "content_block_delta",
@ -194,10 +191,18 @@ impl TryFrom<ConverseStreamEvent> for MessagesStreamEvent {
let anthropic_stop_reason = match stop_event.stop_reason { let anthropic_stop_reason = match stop_event.stop_reason {
crate::apis::amazon_bedrock::StopReason::EndTurn => MessagesStopReason::EndTurn, crate::apis::amazon_bedrock::StopReason::EndTurn => MessagesStopReason::EndTurn,
crate::apis::amazon_bedrock::StopReason::ToolUse => MessagesStopReason::ToolUse, crate::apis::amazon_bedrock::StopReason::ToolUse => MessagesStopReason::ToolUse,
crate::apis::amazon_bedrock::StopReason::MaxTokens => MessagesStopReason::MaxTokens, crate::apis::amazon_bedrock::StopReason::MaxTokens => {
crate::apis::amazon_bedrock::StopReason::StopSequence => MessagesStopReason::EndTurn, MessagesStopReason::MaxTokens
crate::apis::amazon_bedrock::StopReason::GuardrailIntervened => MessagesStopReason::Refusal, }
crate::apis::amazon_bedrock::StopReason::ContentFiltered => MessagesStopReason::Refusal, crate::apis::amazon_bedrock::StopReason::StopSequence => {
MessagesStopReason::EndTurn
}
crate::apis::amazon_bedrock::StopReason::GuardrailIntervened => {
MessagesStopReason::Refusal
}
crate::apis::amazon_bedrock::StopReason::ContentFiltered => {
MessagesStopReason::Refusal
}
}; };
Ok(MessagesStreamEvent::MessageDelta { Ok(MessagesStreamEvent::MessageDelta {

View file

@ -1,8 +1,10 @@
use crate::apis::amazon_bedrock::{ ConverseStreamEvent, StopReason}; use crate::apis::amazon_bedrock::{ConverseStreamEvent, StopReason};
use crate::apis::anthropic::{ use crate::apis::anthropic::{
MessagesContentBlock, MessagesContentDelta, MessagesStopReason, MessagesStreamEvent}; MessagesContentBlock, MessagesContentDelta, MessagesStopReason, MessagesStreamEvent,
use crate::apis::openai::{ ChatCompletionsStreamResponse,FinishReason, };
FunctionCallDelta, MessageDelta, Role, StreamChoice, ToolCallDelta, Usage, use crate::apis::openai::{
ChatCompletionsStreamResponse, FinishReason, FunctionCallDelta, MessageDelta, Role,
StreamChoice, ToolCallDelta, Usage,
}; };
use crate::apis::openai_responses::ResponsesAPIStreamEvent; use crate::apis::openai_responses::ResponsesAPIStreamEvent;
@ -58,11 +60,14 @@ impl TryFrom<MessagesStreamEvent> for ChatCompletionsStreamResponse {
None, None,
)), )),
MessagesStreamEvent::ContentBlockStart { content_block, index } => { MessagesStreamEvent::ContentBlockStart {
convert_content_block_start(content_block, index) content_block,
} index,
} => convert_content_block_start(content_block, index),
MessagesStreamEvent::ContentBlockDelta { delta, index } => convert_content_delta(delta, index), MessagesStreamEvent::ContentBlockDelta { delta, index } => {
convert_content_delta(delta, index)
}
MessagesStreamEvent::ContentBlockStop { .. } => Ok(create_empty_openai_chunk()), MessagesStreamEvent::ContentBlockStop { .. } => Ok(create_empty_openai_chunk()),
@ -427,9 +432,9 @@ fn create_empty_openai_chunk() -> ChatCompletionsStreamResponse {
} }
// Stop Reason Conversions // Stop Reason Conversions
impl Into<FinishReason> for MessagesStopReason { impl From<MessagesStopReason> for FinishReason {
fn into(self) -> FinishReason { fn from(val: MessagesStopReason) -> Self {
match self { match val {
MessagesStopReason::EndTurn => FinishReason::Stop, MessagesStopReason::EndTurn => FinishReason::Stop,
MessagesStopReason::MaxTokens => FinishReason::Length, MessagesStopReason::MaxTokens => FinishReason::Length,
MessagesStopReason::StopSequence => FinishReason::Stop, MessagesStopReason::StopSequence => FinishReason::Stop,
@ -456,34 +461,37 @@ impl TryFrom<ChatCompletionsStreamResponse> for ResponsesAPIStreamEvent {
if let Some(tool_call) = tool_calls.first() { if let Some(tool_call) = tool_calls.first() {
// Extract call_id and name if available (metadata from initial event) // Extract call_id and name if available (metadata from initial event)
let call_id = tool_call.id.clone(); let call_id = tool_call.id.clone();
let function_name = tool_call.function.as_ref() let function_name = tool_call.function.as_ref().and_then(|f| f.name.clone());
.and_then(|f| f.name.clone());
// Check if we have function metadata (name, id) // Check if we have function metadata (name, id)
if let Some(function) = &tool_call.function { if let Some(function) = &tool_call.function {
// If we have arguments delta, return that // If we have arguments delta, return that
if let Some(args) = &function.arguments { if let Some(args) = &function.arguments {
return Ok(ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { return Ok(
output_index: choice.index as i32, ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta {
item_id: "".to_string(), // Buffer will fill this output_index: choice.index as i32,
delta: args.clone(), item_id: "".to_string(), // Buffer will fill this
sequence_number: 0, // Buffer will fill this delta: args.clone(),
call_id, sequence_number: 0, // Buffer will fill this
name: function_name, call_id,
}); name: function_name,
},
);
} }
// If we have function name but no arguments yet (initial tool call event) // If we have function name but no arguments yet (initial tool call event)
// Return an empty arguments delta so the buffer knows to create the item // Return an empty arguments delta so the buffer knows to create the item
if function.name.is_some() { if function.name.is_some() {
return Ok(ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { return Ok(
output_index: choice.index as i32, ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta {
item_id: "".to_string(), // Buffer will fill this output_index: choice.index as i32,
delta: "".to_string(), // Empty delta signals this is the initial event item_id: "".to_string(), // Buffer will fill this
sequence_number: 0, // Buffer will fill this delta: "".to_string(), // Empty delta signals this is the initial event
call_id, sequence_number: 0, // Buffer will fill this
name: function_name, call_id,
}); name: function_name,
},
);
} }
} }
} }

View file

@ -94,8 +94,8 @@ impl StreamContext {
fn request_identifier(&self) -> String { fn request_identifier(&self) -> String {
self.request_id self.request_id
.as_ref() .as_ref()
.filter(|id| !id.is_empty()) // Filter out empty strings .filter(|id| !id.is_empty())
.map(|id| id.clone()) .cloned()
.unwrap_or_else(|| "NO_REQUEST_ID".to_string()) .unwrap_or_else(|| "NO_REQUEST_ID".to_string())
} }
fn llm_provider(&self) -> &LlmProvider { fn llm_provider(&self) -> &LlmProvider {
@ -504,7 +504,7 @@ impl StreamContext {
// Get accumulated bytes from buffer and return // Get accumulated bytes from buffer and return
match self.sse_buffer.as_mut() { match self.sse_buffer.as_mut() {
Some(buffer) => { Some(buffer) => {
let bytes = buffer.into_bytes(); let bytes = buffer.to_bytes();
if !bytes.is_empty() { if !bytes.is_empty() {
let content = String::from_utf8_lossy(&bytes); let content = String::from_utf8_lossy(&bytes);
debug!( debug!(
@ -623,7 +623,7 @@ impl StreamContext {
// Get accumulated bytes from buffer and return // Get accumulated bytes from buffer and return
match self.sse_buffer.as_mut() { match self.sse_buffer.as_mut() {
Some(buffer) => { Some(buffer) => {
let bytes = buffer.into_bytes(); let bytes = buffer.to_bytes();
if !bytes.is_empty() { if !bytes.is_empty() {
let content = String::from_utf8_lossy(&bytes); let content = String::from_utf8_lossy(&bytes);
debug!( debug!(

View file

@ -142,8 +142,7 @@ impl HttpContext for StreamContext {
let last_user_prompt = match deserialized_body let last_user_prompt = match deserialized_body
.messages .messages
.iter() .iter()
.filter(|msg| msg.role == USER_ROLE) .rfind(|msg| msg.role == USER_ROLE)
.last()
{ {
Some(content) => content, Some(content) => content,
None => { None => {
@ -155,11 +154,8 @@ impl HttpContext for StreamContext {
self.user_prompt = Some(last_user_prompt.clone()); self.user_prompt = Some(last_user_prompt.clone());
// convert prompt targets to ChatCompletionTool // convert prompt targets to ChatCompletionTool
let tool_calls: Vec<ChatCompletionTool> = self let tool_calls: Vec<ChatCompletionTool> =
.prompt_targets self.prompt_targets.values().map(|pt| pt.into()).collect();
.iter()
.map(|(_, pt)| pt.into())
.collect();
let mut metadata = deserialized_body.metadata.clone(); let mut metadata = deserialized_body.metadata.clone();

View file

@ -376,21 +376,22 @@ impl StreamContext {
// Parse arguments JSON string into HashMap // Parse arguments JSON string into HashMap
// Note: convert from serde_json::Value to serde_yaml::Value for compatibility // Note: convert from serde_json::Value to serde_yaml::Value for compatibility
let tool_params: Option<HashMap<String, serde_yaml::Value>> = match serde_json::from_str::<HashMap<String, serde_json::Value>>(tool_params_str) { let tool_params: Option<HashMap<String, serde_yaml::Value>> =
Ok(json_params) => { match serde_json::from_str::<HashMap<String, serde_json::Value>>(tool_params_str) {
let yaml_params: HashMap<String, serde_yaml::Value> = json_params Ok(json_params) => {
.into_iter() let yaml_params: HashMap<String, serde_yaml::Value> = json_params
.filter_map(|(k, v)| { .into_iter()
serde_yaml::to_value(&v).ok().map(|yaml_v| (k, yaml_v)) .filter_map(|(k, v)| {
}) serde_yaml::to_value(&v).ok().map(|yaml_v| (k, yaml_v))
.collect(); })
Some(yaml_params) .collect();
}, Some(yaml_params)
Err(e) => { }
warn!("Failed to parse tool call arguments: {}", e); Err(e) => {
None warn!("Failed to parse tool call arguments: {}", e);
} None
}; }
};
let endpoint_details = prompt_target.endpoint.as_ref().unwrap(); let endpoint_details = prompt_target.endpoint.as_ref().unwrap();
let endpoint_path: String = endpoint_details let endpoint_path: String = endpoint_details
@ -629,10 +630,10 @@ impl StreamContext {
} }
}; };
if system_prompt.is_some() { if let Some(system_prompt_text) = system_prompt {
let system_prompt_message = Message { let system_prompt_message = Message {
role: SYSTEM_ROLE.to_string(), role: SYSTEM_ROLE.to_string(),
content: Some(ContentType::Text(system_prompt.unwrap())), content: Some(ContentType::Text(system_prompt_text)),
model: None, model: None,
tool_calls: None, tool_calls: None,
tool_call_id: None, tool_call_id: None,