mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
cargo clippy (#660)
This commit is contained in:
parent
c75e7606f9
commit
ca95ffb63d
62 changed files with 1864 additions and 1187 deletions
|
|
@ -13,19 +13,22 @@ repos:
|
|||
name: cargo-fmt
|
||||
language: system
|
||||
types: [file, rust]
|
||||
entry: bash -c "cd crates/llm_gateway && cargo fmt"
|
||||
entry: bash -c "cd crates && cargo fmt --all -- --check"
|
||||
pass_filenames: false
|
||||
|
||||
- id: cargo-clippy
|
||||
name: cargo-clippy
|
||||
language: system
|
||||
types: [file, rust]
|
||||
entry: bash -c "cd crates/llm_gateway && cargo clippy --all"
|
||||
entry: bash -c "cd crates && cargo clippy --locked --offline --all-targets --all-features -- -D warnings || cargo clippy --locked --all-targets --all-features -- -D warnings"
|
||||
pass_filenames: false
|
||||
|
||||
- id: cargo-test
|
||||
name: cargo-test
|
||||
language: system
|
||||
types: [file, rust]
|
||||
entry: bash -c "cd crates && cargo test --lib"
|
||||
pass_filenames: false
|
||||
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 23.1.0
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ use std::time::{Instant, SystemTime};
|
|||
|
||||
use bytes::Bytes;
|
||||
use common::consts::TRACE_PARENT_HEADER;
|
||||
use common::traces::{SpanBuilder, SpanKind, parse_traceparent, generate_random_span_id};
|
||||
use common::traces::{generate_random_span_id, parse_traceparent, SpanBuilder, SpanKind};
|
||||
use hermesllm::apis::OpenAIMessage;
|
||||
use hermesllm::clients::SupportedAPIsFromClient;
|
||||
use hermesllm::providers::request::ProviderRequest;
|
||||
|
|
@ -18,7 +18,7 @@ use super::agent_selector::{AgentSelectionError, AgentSelector};
|
|||
use super::pipeline_processor::{PipelineError, PipelineProcessor};
|
||||
use super::response_handler::ResponseHandler;
|
||||
use crate::router::plano_orchestrator::OrchestratorService;
|
||||
use crate::tracing::{OperationNameBuilder, operation_component, http};
|
||||
use crate::tracing::{http, operation_component, OperationNameBuilder};
|
||||
|
||||
/// Main errors for agent chat completions
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
|
|
@ -61,7 +61,6 @@ pub async fn agent_chat(
|
|||
body,
|
||||
}) = &err
|
||||
{
|
||||
|
||||
warn!(
|
||||
"Client error from agent '{}' (HTTP {}): {}",
|
||||
agent, status, body
|
||||
|
|
@ -77,8 +76,8 @@ pub async fn agent_chat(
|
|||
|
||||
let json_string = error_json.to_string();
|
||||
let mut response = Response::new(ResponseHandler::create_full_body(json_string));
|
||||
*response.status_mut() = hyper::StatusCode::from_u16(*status)
|
||||
.unwrap_or(hyper::StatusCode::BAD_REQUEST);
|
||||
*response.status_mut() =
|
||||
hyper::StatusCode::from_u16(*status).unwrap_or(hyper::StatusCode::BAD_REQUEST);
|
||||
response.headers_mut().insert(
|
||||
hyper::header::CONTENT_TYPE,
|
||||
"application/json".parse().unwrap(),
|
||||
|
|
@ -234,8 +233,18 @@ async fn handle_agent_chat(
|
|||
.with_attribute(http::TARGET, "/agents/select")
|
||||
.with_attribute("selection.listener", listener.name.clone())
|
||||
.with_attribute("selection.agent_count", selected_agents.len().to_string())
|
||||
.with_attribute("selection.agents", selected_agents.iter().map(|a| a.id.as_str()).collect::<Vec<_>>().join(","))
|
||||
.with_attribute("duration_ms", format!("{:.2}", selection_elapsed.as_secs_f64() * 1000.0));
|
||||
.with_attribute(
|
||||
"selection.agents",
|
||||
selected_agents
|
||||
.iter()
|
||||
.map(|a| a.id.as_str())
|
||||
.collect::<Vec<_>>()
|
||||
.join(","),
|
||||
)
|
||||
.with_attribute(
|
||||
"duration_ms",
|
||||
format!("{:.2}", selection_elapsed.as_secs_f64() * 1000.0),
|
||||
);
|
||||
|
||||
if !trace_id.is_empty() {
|
||||
selection_span_builder = selection_span_builder.with_trace_id(trace_id.clone());
|
||||
|
|
@ -318,8 +327,14 @@ async fn handle_agent_chat(
|
|||
.with_attribute(http::METHOD, "POST")
|
||||
.with_attribute(http::TARGET, full_path)
|
||||
.with_attribute("agent.name", agent_name.clone())
|
||||
.with_attribute("agent.sequence", format!("{}/{}", agent_index + 1, agent_count))
|
||||
.with_attribute("duration_ms", format!("{:.2}", agent_elapsed.as_secs_f64() * 1000.0));
|
||||
.with_attribute(
|
||||
"agent.sequence",
|
||||
format!("{}/{}", agent_index + 1, agent_count),
|
||||
)
|
||||
.with_attribute(
|
||||
"duration_ms",
|
||||
format!("{:.2}", agent_elapsed.as_secs_f64() * 1000.0),
|
||||
);
|
||||
|
||||
if !trace_id.is_empty() {
|
||||
span_builder = span_builder.with_trace_id(trace_id.clone());
|
||||
|
|
@ -333,7 +348,10 @@ async fn handle_agent_chat(
|
|||
|
||||
// If this is the last agent, return the streaming response
|
||||
if is_last_agent {
|
||||
info!("Completed agent chain, returning response from last agent: {}", agent_name);
|
||||
info!(
|
||||
"Completed agent chain, returning response from last agent: {}",
|
||||
agent_name
|
||||
);
|
||||
return response_handler
|
||||
.create_streaming_response(llm_response)
|
||||
.await
|
||||
|
|
@ -341,7 +359,10 @@ async fn handle_agent_chat(
|
|||
}
|
||||
|
||||
// For intermediate agents, collect the full response and pass to next agent
|
||||
debug!("Collecting response from intermediate agent: {}", agent_name);
|
||||
debug!(
|
||||
"Collecting response from intermediate agent: {}",
|
||||
agent_name
|
||||
);
|
||||
let response_text = response_handler.collect_full_response(llm_response).await?;
|
||||
|
||||
info!(
|
||||
|
|
@ -364,7 +385,6 @@ async fn handle_agent_chat(
|
|||
});
|
||||
|
||||
current_messages.push(last_message);
|
||||
|
||||
}
|
||||
|
||||
// This should never be reached since we return in the last agent iteration
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ use std::collections::HashMap;
|
|||
use std::sync::Arc;
|
||||
|
||||
use common::configuration::{
|
||||
Agent, AgentFilterChain, Listener, AgentUsagePreference, OrchestrationPreference,
|
||||
Agent, AgentFilterChain, AgentUsagePreference, Listener, OrchestrationPreference,
|
||||
};
|
||||
use hermesllm::apis::openai::Message;
|
||||
use tracing::{debug, warn};
|
||||
|
|
|
|||
|
|
@ -1,20 +1,18 @@
|
|||
use bytes::Bytes;
|
||||
use eventsource_stream::Eventsource;
|
||||
use futures::StreamExt;
|
||||
use hermesllm::apis::openai::{
|
||||
ChatCompletionsRequest, ChatCompletionsResponse, Choice, FinishReason, FunctionCall, Message,
|
||||
MessageContent, ResponseMessage, Role, Tool, ToolCall, Usage,
|
||||
};
|
||||
use http_body_util::{combinators::BoxBody, BodyExt, Full};
|
||||
use hyper::body::Incoming;
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use std::collections::HashMap;
|
||||
use thiserror::Error;
|
||||
use tracing::{info, error};
|
||||
use futures::StreamExt;
|
||||
use bytes::Bytes;
|
||||
use http_body_util::{combinators::BoxBody, BodyExt, Full};
|
||||
use hyper::body::Incoming;
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use eventsource_stream::Eventsource;
|
||||
|
||||
|
||||
use tracing::{error, info};
|
||||
|
||||
// ============================================================================
|
||||
// CONSTANTS FOR HALLUCINATION DETECTION
|
||||
|
|
@ -273,17 +271,14 @@ impl ArchFunctionHandler {
|
|||
let mut stack: Vec<char> = Vec::new();
|
||||
let mut fixed_str = String::new();
|
||||
|
||||
let matching_bracket: HashMap<char, char> =
|
||||
[(')', '('), ('}', '{'), (']', '[')]
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
let opening_bracket: HashMap<char, char> = matching_bracket
|
||||
let matching_bracket: HashMap<char, char> = [(')', '('), ('}', '{'), (']', '[')]
|
||||
.iter()
|
||||
.map(|(k, v)| (*v, *k))
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
let opening_bracket: HashMap<char, char> =
|
||||
matching_bracket.iter().map(|(k, v)| (*v, *k)).collect();
|
||||
|
||||
for ch in json_str.chars() {
|
||||
if ch == '{' || ch == '[' || ch == '(' {
|
||||
stack.push(ch);
|
||||
|
|
@ -332,12 +327,18 @@ impl ArchFunctionHandler {
|
|||
// Remove markdown code blocks
|
||||
let mut content = content.trim().to_string();
|
||||
if content.starts_with("```") && content.ends_with("```") {
|
||||
content = content.trim_start_matches("```").trim_end_matches("```").to_string();
|
||||
content = content
|
||||
.trim_start_matches("```")
|
||||
.trim_end_matches("```")
|
||||
.to_string();
|
||||
if content.starts_with("json") {
|
||||
content = content.trim_start_matches("json").to_string();
|
||||
}
|
||||
// Trim again after removing code blocks to eliminate internal whitespace
|
||||
content = content.trim_start_matches(r"\n").trim_end_matches(r"\n").to_string();
|
||||
content = content
|
||||
.trim_start_matches(r"\n")
|
||||
.trim_end_matches(r"\n")
|
||||
.to_string();
|
||||
content = content.trim().to_string();
|
||||
// Unescape the quotes: \" -> "
|
||||
// The model sometimes returns escaped JSON inside markdown blocks
|
||||
|
|
@ -453,12 +454,12 @@ impl ArchFunctionHandler {
|
|||
/// Helper method to check if a value matches the expected type
|
||||
fn check_value_type(&self, value: &Value, target_type: &str) -> bool {
|
||||
match target_type {
|
||||
"int" | "integer" => value.is_i64() || value.is_u64(),
|
||||
"int" | "integer" => value.is_i64() || value.is_u64(),
|
||||
"float" | "number" => value.is_f64() || value.is_i64() || value.is_u64(),
|
||||
"bool" | "boolean" => value.is_boolean(),
|
||||
"str" | "string" => value.is_string(),
|
||||
"list" | "array" => value.is_array(),
|
||||
"dict" | "object" => value.is_object(),
|
||||
"bool" | "boolean" => value.is_boolean(),
|
||||
"str" | "string" => value.is_string(),
|
||||
"list" | "array" => value.is_array(),
|
||||
"dict" | "object" => value.is_object(),
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
|
|
@ -505,15 +506,19 @@ impl ArchFunctionHandler {
|
|||
let func_name = &tool_call.function.name;
|
||||
|
||||
// Parse arguments as JSON
|
||||
let func_args: HashMap<String, Value> = match serde_json::from_str(&tool_call.function.arguments) {
|
||||
Ok(args) => args,
|
||||
Err(e) => {
|
||||
verification.is_valid = false;
|
||||
verification.invalid_tool_call = Some(tool_call.clone());
|
||||
verification.error_message = format!("Failed to parse arguments for function '{}': {}", func_name, e);
|
||||
break;
|
||||
}
|
||||
};
|
||||
let func_args: HashMap<String, Value> =
|
||||
match serde_json::from_str(&tool_call.function.arguments) {
|
||||
Ok(args) => args,
|
||||
Err(e) => {
|
||||
verification.is_valid = false;
|
||||
verification.invalid_tool_call = Some(tool_call.clone());
|
||||
verification.error_message = format!(
|
||||
"Failed to parse arguments for function '{}': {}",
|
||||
func_name, e
|
||||
);
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
// Check if function is available
|
||||
if let Some(function_params) = functions.get(func_name) {
|
||||
|
|
@ -541,14 +546,23 @@ impl ArchFunctionHandler {
|
|||
if let Some(properties_obj) = properties.as_object() {
|
||||
for (param_name, param_value) in &func_args {
|
||||
if let Some(param_schema) = properties_obj.get(param_name) {
|
||||
if let Some(target_type) = param_schema.get("type").and_then(|v| v.as_str()) {
|
||||
if self.config.support_data_types.contains(&target_type.to_string()) {
|
||||
if let Some(target_type) =
|
||||
param_schema.get("type").and_then(|v| v.as_str())
|
||||
{
|
||||
if self
|
||||
.config
|
||||
.support_data_types
|
||||
.contains(&target_type.to_string())
|
||||
{
|
||||
// Validate data type using helper method
|
||||
match self.validate_or_convert_parameter(param_value, target_type) {
|
||||
match self
|
||||
.validate_or_convert_parameter(param_value, target_type)
|
||||
{
|
||||
Ok(is_valid) => {
|
||||
if !is_valid {
|
||||
verification.is_valid = false;
|
||||
verification.invalid_tool_call = Some(tool_call.clone());
|
||||
verification.invalid_tool_call =
|
||||
Some(tool_call.clone());
|
||||
verification.error_message = format!(
|
||||
"Parameter `{}` is expected to have the data type `{}`, got incompatible type.",
|
||||
param_name, target_type
|
||||
|
|
@ -558,7 +572,8 @@ impl ArchFunctionHandler {
|
|||
}
|
||||
Err(_) => {
|
||||
verification.is_valid = false;
|
||||
verification.invalid_tool_call = Some(tool_call.clone());
|
||||
verification.invalid_tool_call =
|
||||
Some(tool_call.clone());
|
||||
verification.error_message = format!(
|
||||
"Parameter `{}` is expected to have the data type `{}`, got incompatible type.",
|
||||
param_name, target_type
|
||||
|
|
@ -569,7 +584,10 @@ impl ArchFunctionHandler {
|
|||
} else {
|
||||
verification.is_valid = false;
|
||||
verification.invalid_tool_call = Some(tool_call.clone());
|
||||
verification.error_message = format!("Data type `{}` is not supported.", target_type);
|
||||
verification.error_message = format!(
|
||||
"Data type `{}` is not supported.",
|
||||
target_type
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
@ -598,11 +616,8 @@ impl ArchFunctionHandler {
|
|||
/// Formats the system prompt with tools
|
||||
pub fn format_system_prompt(&self, tools: &[Tool]) -> Result<String> {
|
||||
let tools_str = self.convert_tools(tools)?;
|
||||
let system_prompt = self
|
||||
.config
|
||||
.task_prompt
|
||||
.replace("{tools}", &tools_str)
|
||||
+ &self.config.format_prompt;
|
||||
let system_prompt =
|
||||
self.config.task_prompt.replace("{tools}", &tools_str) + &self.config.format_prompt;
|
||||
|
||||
Ok(system_prompt)
|
||||
}
|
||||
|
|
@ -665,15 +680,22 @@ impl ArchFunctionHandler {
|
|||
|
||||
// Strip markdown code blocks
|
||||
if tool_call_msg.starts_with("```") && tool_call_msg.ends_with("```") {
|
||||
tool_call_msg = tool_call_msg.trim_start_matches("```").trim_end_matches("```").trim().to_string();
|
||||
tool_call_msg = tool_call_msg
|
||||
.trim_start_matches("```")
|
||||
.trim_end_matches("```")
|
||||
.trim()
|
||||
.to_string();
|
||||
if tool_call_msg.starts_with("json") {
|
||||
tool_call_msg = tool_call_msg.trim_start_matches("json").trim().to_string();
|
||||
tool_call_msg =
|
||||
tool_call_msg.trim_start_matches("json").trim().to_string();
|
||||
}
|
||||
}
|
||||
|
||||
// Extract function name
|
||||
if let Ok(parsed) = serde_json::from_str::<Value>(&tool_call_msg) {
|
||||
if let Some(tool_calls_arr) = parsed.get("tool_calls").and_then(|v| v.as_array()) {
|
||||
if let Some(tool_calls_arr) =
|
||||
parsed.get("tool_calls").and_then(|v| v.as_array())
|
||||
{
|
||||
if let Some(first_tool_call) = tool_calls_arr.first() {
|
||||
let func_name = first_tool_call
|
||||
.get("name")
|
||||
|
|
@ -685,8 +707,10 @@ impl ArchFunctionHandler {
|
|||
"result": content,
|
||||
});
|
||||
|
||||
content = format!("<tool_response>\n{}\n</tool_response>",
|
||||
serde_json::to_string(&tool_response)?);
|
||||
content = format!(
|
||||
"<tool_response>\n{}\n</tool_response>",
|
||||
serde_json::to_string(&tool_response)?
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -717,7 +741,7 @@ impl ArchFunctionHandler {
|
|||
if let Some(instruction) = extra_instruction {
|
||||
if let Some(last) = processed_messages.last_mut() {
|
||||
if let MessageContent::Text(content) = &mut last.content {
|
||||
content.push_str("\n");
|
||||
content.push('\n');
|
||||
content.push_str(instruction);
|
||||
}
|
||||
}
|
||||
|
|
@ -750,13 +774,11 @@ impl ArchFunctionHandler {
|
|||
for i in (conversation_idx..messages.len()).rev() {
|
||||
if let MessageContent::Text(content) = &messages[i].content {
|
||||
num_tokens += content.len() / 4;
|
||||
if num_tokens >= max_tokens {
|
||||
if messages[i].role == Role::User {
|
||||
// Set message_idx to current position and break
|
||||
// This matches Python's behavior where message_idx is set before break
|
||||
message_idx = i;
|
||||
break;
|
||||
}
|
||||
if num_tokens >= max_tokens && messages[i].role == Role::User {
|
||||
// Set message_idx to current position and break
|
||||
// This matches Python's behavior where message_idx is set before break
|
||||
message_idx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Only update message_idx if we haven't hit the token limit yet
|
||||
|
|
@ -789,7 +811,11 @@ impl ArchFunctionHandler {
|
|||
}
|
||||
|
||||
/// Helper to create a request with VLLM-specific parameters
|
||||
fn create_request_with_extra_body(&self, messages: Vec<Message>, stream: bool) -> ChatCompletionsRequest {
|
||||
fn create_request_with_extra_body(
|
||||
&self,
|
||||
messages: Vec<Message>,
|
||||
stream: bool,
|
||||
) -> ChatCompletionsRequest {
|
||||
ChatCompletionsRequest {
|
||||
model: self.model_name.clone(),
|
||||
messages,
|
||||
|
|
@ -813,24 +839,38 @@ impl ArchFunctionHandler {
|
|||
}
|
||||
|
||||
/// Makes a streaming request and returns the SSE event stream
|
||||
async fn make_streaming_request(&self, request: ChatCompletionsRequest) -> Result<std::pin::Pin<Box<dyn futures::Stream<Item = std::result::Result<Value, String>> + Send>>> {
|
||||
let request_body = serde_json::to_string(&request)
|
||||
.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Failed to serialize request: {}", e)))?;
|
||||
async fn make_streaming_request(
|
||||
&self,
|
||||
request: ChatCompletionsRequest,
|
||||
) -> Result<
|
||||
std::pin::Pin<Box<dyn futures::Stream<Item = std::result::Result<Value, String>> + Send>>,
|
||||
> {
|
||||
let request_body = serde_json::to_string(&request).map_err(|e| {
|
||||
FunctionCallingError::InvalidModelResponse(format!(
|
||||
"Failed to serialize request: {}",
|
||||
e
|
||||
))
|
||||
})?;
|
||||
|
||||
let response = self.http_client
|
||||
let response = self
|
||||
.http_client
|
||||
.post(&self.endpoint_url)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(request_body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| FunctionCallingError::HttpError(e))?;
|
||||
.map_err(FunctionCallingError::HttpError)?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(FunctionCallingError::InvalidModelResponse(
|
||||
format!("HTTP error {}: {}", status, error_text)
|
||||
));
|
||||
let error_text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(FunctionCallingError::InvalidModelResponse(format!(
|
||||
"HTTP error {}: {}",
|
||||
status, error_text
|
||||
)));
|
||||
}
|
||||
|
||||
// Parse SSE stream
|
||||
|
|
@ -856,38 +896,51 @@ impl ArchFunctionHandler {
|
|||
}
|
||||
|
||||
/// Makes a non-streaming request and returns the response
|
||||
async fn make_non_streaming_request(&self, request: ChatCompletionsRequest) -> Result<ChatCompletionsResponse> {
|
||||
let request_body = serde_json::to_string(&request)
|
||||
.map_err(|e| FunctionCallingError::InvalidModelResponse(format!("Failed to serialize request: {}", e)))?;
|
||||
async fn make_non_streaming_request(
|
||||
&self,
|
||||
request: ChatCompletionsRequest,
|
||||
) -> Result<ChatCompletionsResponse> {
|
||||
let request_body = serde_json::to_string(&request).map_err(|e| {
|
||||
FunctionCallingError::InvalidModelResponse(format!(
|
||||
"Failed to serialize request: {}",
|
||||
e
|
||||
))
|
||||
})?;
|
||||
|
||||
let response = self.http_client
|
||||
let response = self
|
||||
.http_client
|
||||
.post(&self.endpoint_url)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(request_body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| FunctionCallingError::HttpError(e))?;
|
||||
.map_err(FunctionCallingError::HttpError)?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(FunctionCallingError::InvalidModelResponse(
|
||||
format!("HTTP error {}: {}", status, error_text)
|
||||
));
|
||||
let error_text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(FunctionCallingError::InvalidModelResponse(format!(
|
||||
"HTTP error {}: {}",
|
||||
status, error_text
|
||||
)));
|
||||
}
|
||||
|
||||
let response_text = response.text().await
|
||||
.map_err(|e| FunctionCallingError::HttpError(e))?;
|
||||
let response_text = response
|
||||
.text()
|
||||
.await
|
||||
.map_err(FunctionCallingError::HttpError)?;
|
||||
|
||||
serde_json::from_str(&response_text)
|
||||
.map_err(|e| FunctionCallingError::JsonParseError(e))
|
||||
serde_json::from_str(&response_text).map_err(FunctionCallingError::JsonParseError)
|
||||
}
|
||||
|
||||
pub async fn function_calling_chat(
|
||||
&self,
|
||||
request: ChatCompletionsRequest,
|
||||
) -> Result<ChatCompletionsResponse> {
|
||||
use tracing::{info, error};
|
||||
use tracing::{error, info};
|
||||
|
||||
info!("[Arch-Function] - ChatCompletion");
|
||||
|
||||
|
|
@ -899,10 +952,14 @@ impl ArchFunctionHandler {
|
|||
request.metadata.as_ref(),
|
||||
)?;
|
||||
|
||||
info!("[request to arch-fc]: model: {}, messages count: {}",
|
||||
self.model_name, messages.len());
|
||||
info!(
|
||||
"[request to arch-fc]: model: {}, messages count: {}",
|
||||
self.model_name,
|
||||
messages.len()
|
||||
);
|
||||
|
||||
let use_agent_orchestrator = request.metadata
|
||||
let use_agent_orchestrator = request
|
||||
.metadata
|
||||
.as_ref()
|
||||
.and_then(|m| m.get("use_agent_orchestrator"))
|
||||
.and_then(|v| v.as_bool())
|
||||
|
|
@ -918,89 +975,95 @@ impl ArchFunctionHandler {
|
|||
|
||||
if use_agent_orchestrator {
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?;
|
||||
let chunk = chunk_result.map_err(FunctionCallingError::InvalidModelResponse)?;
|
||||
// Extract content from JSON response
|
||||
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
|
||||
if let Some(choice) = choices.first() {
|
||||
if let Some(content) = choice.get("delta")
|
||||
if let Some(content) = choice
|
||||
.get("delta")
|
||||
.and_then(|d| d.get("content"))
|
||||
.and_then(|c| c.as_str()) {
|
||||
.and_then(|c| c.as_str())
|
||||
{
|
||||
model_response.push_str(content);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
info!("[Agent Orchestrator]: response received");
|
||||
} else {
|
||||
if let Some(tools) = request.tools.as_ref() {
|
||||
let mut hallucination_state = HallucinationState::new(tools);
|
||||
let mut has_tool_calls = None;
|
||||
let mut has_hallucination = false;
|
||||
} else if let Some(tools) = request.tools.as_ref() {
|
||||
let mut hallucination_state = HallucinationState::new(tools);
|
||||
let mut has_tool_calls = None;
|
||||
let mut has_hallucination = false;
|
||||
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?;
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
let chunk = chunk_result.map_err(FunctionCallingError::InvalidModelResponse)?;
|
||||
|
||||
// Extract content and logprobs from JSON response
|
||||
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
|
||||
if let Some(choice) = choices.first() {
|
||||
if let Some(content) = choice.get("delta")
|
||||
.and_then(|d| d.get("content"))
|
||||
.and_then(|c| c.as_str()) {
|
||||
// Extract content and logprobs from JSON response
|
||||
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
|
||||
if let Some(choice) = choices.first() {
|
||||
if let Some(content) = choice
|
||||
.get("delta")
|
||||
.and_then(|d| d.get("content"))
|
||||
.and_then(|c| c.as_str())
|
||||
{
|
||||
// Extract logprobs
|
||||
let logprobs: Vec<f64> = choice
|
||||
.get("logprobs")
|
||||
.and_then(|lp| lp.get("content"))
|
||||
.and_then(|c| c.as_array())
|
||||
.and_then(|arr| arr.first())
|
||||
.and_then(|token| token.get("top_logprobs"))
|
||||
.and_then(|tlp| tlp.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.get("logprob").and_then(|lp| lp.as_f64()))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
// Extract logprobs
|
||||
let logprobs: Vec<f64> = choice.get("logprobs")
|
||||
.and_then(|lp| lp.get("content"))
|
||||
.and_then(|c| c.as_array())
|
||||
.and_then(|arr| arr.first())
|
||||
.and_then(|token| token.get("top_logprobs"))
|
||||
.and_then(|tlp| tlp.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.get("logprob").and_then(|lp| lp.as_f64()))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
if hallucination_state
|
||||
.append_and_check_token_hallucination(content.to_string(), logprobs)
|
||||
{
|
||||
has_hallucination = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if hallucination_state.append_and_check_token_hallucination(content.to_string(), logprobs) {
|
||||
has_hallucination = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if hallucination_state.tokens.len() > 5 && has_tool_calls.is_none() {
|
||||
let collected_content = hallucination_state.tokens.join("");
|
||||
has_tool_calls = Some(collected_content.contains("tool_calls"));
|
||||
}
|
||||
if hallucination_state.tokens.len() > 5 && has_tool_calls.is_none() {
|
||||
let collected_content = hallucination_state.tokens.join("");
|
||||
has_tool_calls = Some(collected_content.contains("tool_calls"));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if has_tool_calls == Some(true) && has_hallucination {
|
||||
info!("[Hallucination]: {}", hallucination_state.error_message);
|
||||
if has_tool_calls == Some(true) && has_hallucination {
|
||||
info!("[Hallucination]: {}", hallucination_state.error_message);
|
||||
|
||||
let clarify_messages = self.prefill_message(messages.clone(), &self.clarify_prefix);
|
||||
let clarify_request = self.create_request_with_extra_body(clarify_messages, false);
|
||||
let clarify_messages = self.prefill_message(messages.clone(), &self.clarify_prefix);
|
||||
let clarify_request = self.create_request_with_extra_body(clarify_messages, false);
|
||||
|
||||
let retry_response = self.make_non_streaming_request(clarify_request).await?;
|
||||
let retry_response = self.make_non_streaming_request(clarify_request).await?;
|
||||
|
||||
if let Some(choice) = retry_response.choices.first() {
|
||||
if let Some(content) = &choice.message.content {
|
||||
model_response = content.clone();
|
||||
}
|
||||
if let Some(choice) = retry_response.choices.first() {
|
||||
if let Some(content) = &choice.message.content {
|
||||
model_response = content.clone();
|
||||
}
|
||||
} else {
|
||||
model_response = hallucination_state.tokens.join("");
|
||||
}
|
||||
} else {
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
let chunk = chunk_result.map_err(|e| FunctionCallingError::InvalidModelResponse(e))?;
|
||||
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
|
||||
if let Some(choice) = choices.first() {
|
||||
if let Some(content) = choice.get("delta")
|
||||
.and_then(|d| d.get("content"))
|
||||
.and_then(|c| c.as_str()) {
|
||||
model_response.push_str(content);
|
||||
}
|
||||
model_response = hallucination_state.tokens.join("");
|
||||
}
|
||||
} else {
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
let chunk = chunk_result.map_err(FunctionCallingError::InvalidModelResponse)?;
|
||||
if let Some(choices) = chunk.get("choices").and_then(|v| v.as_array()) {
|
||||
if let Some(choice) = choices.first() {
|
||||
if let Some(content) = choice
|
||||
.get("delta")
|
||||
.and_then(|d| d.get("content"))
|
||||
.and_then(|c| c.as_str())
|
||||
{
|
||||
model_response.push_str(content);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1009,10 +1072,17 @@ impl ArchFunctionHandler {
|
|||
|
||||
let response_dict = self.parse_model_response(&model_response);
|
||||
|
||||
info!("[arch-fc]: raw model response: {}", response_dict.raw_response);
|
||||
info!(
|
||||
"[arch-fc]: raw model response: {}",
|
||||
response_dict.raw_response
|
||||
);
|
||||
|
||||
// General model response (no intent matched - should route to default target)
|
||||
let model_message = if response_dict.response.as_ref().map_or(false, |s| !s.is_empty()) {
|
||||
let model_message = if response_dict
|
||||
.response
|
||||
.as_ref()
|
||||
.is_some_and(|s| !s.is_empty())
|
||||
{
|
||||
// When arch-fc returns a "response" field, it means no intent was matched
|
||||
// Return empty content and empty tool_calls so prompt_gateway routes to default target
|
||||
ResponseMessage {
|
||||
|
|
@ -1053,8 +1123,11 @@ impl ArchFunctionHandler {
|
|||
let verification = self.verify_tool_calls(tools, &response_dict.tool_calls);
|
||||
|
||||
if verification.is_valid {
|
||||
info!("[Tool calls]: {:?}",
|
||||
response_dict.tool_calls.iter()
|
||||
info!(
|
||||
"[Tool calls]: {:?}",
|
||||
response_dict
|
||||
.tool_calls
|
||||
.iter()
|
||||
.map(|tc| &tc.function)
|
||||
.collect::<Vec<_>>()
|
||||
);
|
||||
|
|
@ -1092,8 +1165,11 @@ impl ArchFunctionHandler {
|
|||
}
|
||||
}
|
||||
} else {
|
||||
info!("[Tool calls]: {:?}",
|
||||
response_dict.tool_calls.iter()
|
||||
info!(
|
||||
"[Tool calls]: {:?}",
|
||||
response_dict
|
||||
.tool_calls
|
||||
.iter()
|
||||
.map(|tc| &tc.function)
|
||||
.collect::<Vec<_>>()
|
||||
);
|
||||
|
|
@ -1108,7 +1184,10 @@ impl ArchFunctionHandler {
|
|||
}
|
||||
}
|
||||
} else {
|
||||
error!("Invalid tool calls in response: {}", response_dict.error_message);
|
||||
error!(
|
||||
"Invalid tool calls in response: {}",
|
||||
response_dict.error_message
|
||||
);
|
||||
ResponseMessage {
|
||||
role: Role::Assistant,
|
||||
content: Some(String::new()),
|
||||
|
|
@ -1243,7 +1322,6 @@ pub async fn function_calling_chat_handler(
|
|||
req: Request<Incoming>,
|
||||
llm_provider_url: String,
|
||||
) -> std::result::Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
|
||||
use hermesllm::apis::openai::ChatCompletionsRequest;
|
||||
let whole_body = req.collect().await?.to_bytes();
|
||||
|
||||
|
|
@ -1255,10 +1333,13 @@ pub async fn function_calling_chat_handler(
|
|||
let mut response = Response::new(full(
|
||||
serde_json::json!({
|
||||
"error": format!("Invalid request body: {}", e)
|
||||
}).to_string()
|
||||
})
|
||||
.to_string(),
|
||||
));
|
||||
*response.status_mut() = StatusCode::BAD_REQUEST;
|
||||
response.headers_mut().insert("Content-Type", "application/json".parse().unwrap());
|
||||
response
|
||||
.headers_mut()
|
||||
.insert("Content-Type", "application/json".parse().unwrap());
|
||||
return Ok(response);
|
||||
}
|
||||
};
|
||||
|
|
@ -1271,24 +1352,31 @@ pub async fn function_calling_chat_handler(
|
|||
// Parse as ChatCompletionsRequest
|
||||
let chat_request: ChatCompletionsRequest = match serde_json::from_value(body_json) {
|
||||
Ok(req) => {
|
||||
info!("[request body]: {}", serde_json::to_string(&req).unwrap_or_default());
|
||||
info!(
|
||||
"[request body]: {}",
|
||||
serde_json::to_string(&req).unwrap_or_default()
|
||||
);
|
||||
req
|
||||
},
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to parse request body: {}", e);
|
||||
let mut response = Response::new(full(
|
||||
serde_json::json!({
|
||||
"error": format!("Invalid request body: {}", e)
|
||||
}).to_string()
|
||||
})
|
||||
.to_string(),
|
||||
));
|
||||
*response.status_mut() = StatusCode::BAD_REQUEST;
|
||||
response.headers_mut().insert("Content-Type", "application/json".parse().unwrap());
|
||||
response
|
||||
.headers_mut()
|
||||
.insert("Content-Type", "application/json".parse().unwrap());
|
||||
return Ok(response);
|
||||
}
|
||||
};
|
||||
|
||||
// Determine which handler to use based on metadata
|
||||
let use_agent_orchestrator = chat_request.metadata
|
||||
let use_agent_orchestrator = chat_request
|
||||
.metadata
|
||||
.as_ref()
|
||||
.and_then(|m| m.get("use_agent_orchestrator"))
|
||||
.and_then(|v| v.as_bool())
|
||||
|
|
@ -1309,7 +1397,10 @@ pub async fn function_calling_chat_handler(
|
|||
ARCH_FUNCTION_MODEL_NAME.to_string(),
|
||||
llm_provider_url.clone(),
|
||||
);
|
||||
handler.function_handler.function_calling_chat(chat_request).await
|
||||
handler
|
||||
.function_handler
|
||||
.function_calling_chat(chat_request)
|
||||
.await
|
||||
} else {
|
||||
let handler = ArchFunctionHandler::new(
|
||||
ARCH_FUNCTION_MODEL_NAME.to_string(),
|
||||
|
|
@ -1328,7 +1419,9 @@ pub async fn function_calling_chat_handler(
|
|||
|
||||
let mut response = Response::new(full(response_json));
|
||||
*response.status_mut() = StatusCode::OK;
|
||||
response.headers_mut().insert("Content-Type", "application/json".parse().unwrap());
|
||||
response
|
||||
.headers_mut()
|
||||
.insert("Content-Type", "application/json".parse().unwrap());
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
|
@ -1341,13 +1434,14 @@ pub async fn function_calling_chat_handler(
|
|||
|
||||
let mut response = Response::new(full(error_response.to_string()));
|
||||
*response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
response.headers_mut().insert("Content-Type", "application/json".parse().unwrap());
|
||||
response
|
||||
.headers_mut()
|
||||
.insert("Content-Type", "application/json".parse().unwrap());
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// ============================================================================
|
||||
// TESTS
|
||||
// ============================================================================
|
||||
|
|
@ -1370,10 +1464,13 @@ mod tests {
|
|||
assert!(config.task_prompt.contains("</tools>\\n\\n"));
|
||||
|
||||
// Format prompt should contain literal escaped newlines and proper JSON examples
|
||||
assert!(config.format_prompt.contains("\\n\\nBased on your analysis"));
|
||||
assert!(config.format_prompt.contains(r#"{\"response\": \"Your response text here\"}"#));
|
||||
assert!(config
|
||||
.format_prompt
|
||||
.contains("\\n\\nBased on your analysis"));
|
||||
assert!(config
|
||||
.format_prompt
|
||||
.contains(r#"{\"response\": \"Your response text here\"}"#));
|
||||
assert!(config.format_prompt.contains(r#"{\"tool_calls\": [{"#));
|
||||
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -1384,7 +1481,11 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_fix_json_string_valid() {
|
||||
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string());
|
||||
let handler = ArchFunctionHandler::new(
|
||||
"test-model".to_string(),
|
||||
ArchFunctionConfig::default(),
|
||||
"http://localhost:8000".to_string(),
|
||||
);
|
||||
let json_str = r#"{"name": "test", "value": 123}"#;
|
||||
let result = handler.fix_json_string(json_str);
|
||||
assert!(result.is_ok());
|
||||
|
|
@ -1392,7 +1493,11 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_fix_json_string_missing_bracket() {
|
||||
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string());
|
||||
let handler = ArchFunctionHandler::new(
|
||||
"test-model".to_string(),
|
||||
ArchFunctionConfig::default(),
|
||||
"http://localhost:8000".to_string(),
|
||||
);
|
||||
let json_str = r#"{"name": "test", "value": 123"#;
|
||||
let result = handler.fix_json_string(json_str);
|
||||
assert!(result.is_ok());
|
||||
|
|
@ -1402,8 +1507,13 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_parse_model_response_with_tool_calls() {
|
||||
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string());
|
||||
let content = r#"{"tool_calls": [{"name": "get_weather", "arguments": {"location": "NYC"}}]}"#;
|
||||
let handler = ArchFunctionHandler::new(
|
||||
"test-model".to_string(),
|
||||
ArchFunctionConfig::default(),
|
||||
"http://localhost:8000".to_string(),
|
||||
);
|
||||
let content =
|
||||
r#"{"tool_calls": [{"name": "get_weather", "arguments": {"location": "NYC"}}]}"#;
|
||||
let result = handler.parse_model_response(content);
|
||||
|
||||
assert!(result.is_valid);
|
||||
|
|
@ -1413,8 +1523,13 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_parse_model_response_with_clarification() {
|
||||
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string());
|
||||
let content = r#"{"required_functions": ["get_weather"], "clarification": "What location?"}"#;
|
||||
let handler = ArchFunctionHandler::new(
|
||||
"test-model".to_string(),
|
||||
ArchFunctionConfig::default(),
|
||||
"http://localhost:8000".to_string(),
|
||||
);
|
||||
let content =
|
||||
r#"{"required_functions": ["get_weather"], "clarification": "What location?"}"#;
|
||||
let result = handler.parse_model_response(content);
|
||||
|
||||
assert!(result.is_valid);
|
||||
|
|
@ -1424,7 +1539,11 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_convert_data_type_int_to_float() {
|
||||
let handler = ArchFunctionHandler::new("test-model".to_string(), ArchFunctionConfig::default(), "http://localhost:8000".to_string());
|
||||
let handler = ArchFunctionHandler::new(
|
||||
"test-model".to_string(),
|
||||
ArchFunctionConfig::default(),
|
||||
"http://localhost:8000".to_string(),
|
||||
);
|
||||
let value = json!(42);
|
||||
let result = handler.convert_data_type(&value, "float");
|
||||
assert!(result.is_ok());
|
||||
|
|
@ -1504,13 +1623,12 @@ pub fn check_threshold(
|
|||
}
|
||||
|
||||
/// Checks if a parameter is required in the function description
|
||||
pub fn is_parameter_required(
|
||||
function_description: &Value,
|
||||
parameter_name: &str,
|
||||
) -> bool {
|
||||
pub fn is_parameter_required(function_description: &Value, parameter_name: &str) -> bool {
|
||||
if let Some(required) = function_description.get("required") {
|
||||
if let Some(required_arr) = required.as_array() {
|
||||
return required_arr.iter().any(|v| v.as_str() == Some(parameter_name));
|
||||
return required_arr
|
||||
.iter()
|
||||
.any(|v| v.as_str() == Some(parameter_name));
|
||||
}
|
||||
}
|
||||
false
|
||||
|
|
@ -1559,12 +1677,7 @@ impl HallucinationState {
|
|||
pub fn new(functions: &[Tool]) -> Self {
|
||||
let function_properties: HashMap<String, Value> = functions
|
||||
.iter()
|
||||
.map(|tool| {
|
||||
(
|
||||
tool.function.name.clone(),
|
||||
tool.function.parameters.clone(),
|
||||
)
|
||||
})
|
||||
.map(|tool| (tool.function.name.clone(), tool.function.parameters.clone()))
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
|
|
@ -1620,7 +1733,10 @@ impl HallucinationState {
|
|||
|
||||
// Function name extraction logic
|
||||
if self.state.as_deref() == Some("function_name") {
|
||||
if !FUNC_NAME_END_TOKEN.iter().any(|&t| self.tokens.last().map_or(false, |tok| tok == t)) {
|
||||
if !FUNC_NAME_END_TOKEN
|
||||
.iter()
|
||||
.any(|&t| self.tokens.last().is_some_and(|tok| tok == t))
|
||||
{
|
||||
self.mask.push(MaskToken::FunctionName);
|
||||
} else {
|
||||
self.state = None;
|
||||
|
|
@ -1629,34 +1745,51 @@ impl HallucinationState {
|
|||
}
|
||||
|
||||
// Check for function name start
|
||||
if FUNC_NAME_START_PATTERN.iter().any(|&p| content.ends_with(p)) {
|
||||
if FUNC_NAME_START_PATTERN
|
||||
.iter()
|
||||
.any(|&p| content.ends_with(p))
|
||||
{
|
||||
self.state = Some("function_name".to_string());
|
||||
}
|
||||
|
||||
// Parameter name extraction logic
|
||||
if self.state.as_deref() == Some("parameter_name")
|
||||
&& !PARAMETER_NAME_END_TOKENS.iter().any(|&t| content.ends_with(t)) {
|
||||
&& !PARAMETER_NAME_END_TOKENS
|
||||
.iter()
|
||||
.any(|&t| content.ends_with(t))
|
||||
{
|
||||
self.mask.push(MaskToken::ParameterName);
|
||||
} else if self.state.as_deref() == Some("parameter_name")
|
||||
&& PARAMETER_NAME_END_TOKENS.iter().any(|&t| content.ends_with(t)) {
|
||||
&& PARAMETER_NAME_END_TOKENS
|
||||
.iter()
|
||||
.any(|&t| content.ends_with(t))
|
||||
{
|
||||
self.state = None;
|
||||
self.parameter_name_done = true;
|
||||
self.get_parameter_name();
|
||||
} else if self.parameter_name_done
|
||||
&& !self.open_bracket
|
||||
&& PARAMETER_NAME_START_PATTERN.iter().any(|&p| content.ends_with(p)) {
|
||||
&& PARAMETER_NAME_START_PATTERN
|
||||
.iter()
|
||||
.any(|&p| content.ends_with(p))
|
||||
{
|
||||
self.state = Some("parameter_name".to_string());
|
||||
}
|
||||
|
||||
// First parameter value start
|
||||
if FIRST_PARAM_NAME_START_PATTERN.iter().any(|&p| content.ends_with(p)) {
|
||||
if FIRST_PARAM_NAME_START_PATTERN
|
||||
.iter()
|
||||
.any(|&p| content.ends_with(p))
|
||||
{
|
||||
self.state = Some("parameter_name".to_string());
|
||||
}
|
||||
|
||||
// Parameter value extraction logic
|
||||
if self.state.as_deref() == Some("parameter_value")
|
||||
&& !PARAMETER_VALUE_END_TOKEN.iter().any(|&t| content.ends_with(t)) {
|
||||
|
||||
&& !PARAMETER_VALUE_END_TOKEN
|
||||
.iter()
|
||||
.any(|&t| content.ends_with(t))
|
||||
{
|
||||
// Check for brackets
|
||||
if let Some(last_token) = self.tokens.last() {
|
||||
let open_brackets: Vec<char> = last_token
|
||||
|
|
@ -1694,8 +1827,11 @@ impl HallucinationState {
|
|||
&& self.mask[self.mask.len() - 2] != MaskToken::ParameterValue
|
||||
&& !self.parameter_name.is_empty()
|
||||
{
|
||||
let last_param = self.parameter_name[self.parameter_name.len() - 1].clone();
|
||||
if let Some(func_props) = self.function_properties.get(&self.function_name) {
|
||||
let last_param =
|
||||
self.parameter_name[self.parameter_name.len() - 1].clone();
|
||||
if let Some(func_props) =
|
||||
self.function_properties.get(&self.function_name)
|
||||
{
|
||||
if is_parameter_required(func_props, &last_param)
|
||||
&& !is_parameter_property(func_props, &last_param, "enum")
|
||||
&& !self.check_parameter_name.contains_key(&last_param)
|
||||
|
|
@ -1718,10 +1854,16 @@ impl HallucinationState {
|
|||
}
|
||||
} else if self.state.as_deref() == Some("parameter_value")
|
||||
&& !self.open_bracket
|
||||
&& PARAMETER_VALUE_END_TOKEN.iter().any(|&t| content.ends_with(t)) {
|
||||
&& PARAMETER_VALUE_END_TOKEN
|
||||
.iter()
|
||||
.any(|&t| content.ends_with(t))
|
||||
{
|
||||
self.state = None;
|
||||
} else if self.parameter_name_done
|
||||
&& PARAMETER_VALUE_START_PATTERN.iter().any(|&p| content.ends_with(p)) {
|
||||
&& PARAMETER_VALUE_START_PATTERN
|
||||
.iter()
|
||||
.any(|&p| content.ends_with(p))
|
||||
{
|
||||
self.state = Some("parameter_value".to_string());
|
||||
}
|
||||
|
||||
|
|
@ -1848,18 +1990,18 @@ mod hallucination_tests {
|
|||
let handler = ArchFunctionHandler::new(
|
||||
"test-model".to_string(),
|
||||
ArchFunctionConfig::default(),
|
||||
"http://localhost:8000".to_string()
|
||||
"http://localhost:8000".to_string(),
|
||||
);
|
||||
|
||||
// Test integer types
|
||||
assert!(handler.check_value_type(&json!(42), "integer"));
|
||||
assert!(handler.check_value_type(&json!(42), "int"));
|
||||
assert!(!handler.check_value_type(&json!(3.14), "integer"));
|
||||
assert!(!handler.check_value_type(&json!(3.15), "integer"));
|
||||
|
||||
// Test number types (accepts both int and float)
|
||||
assert!(handler.check_value_type(&json!(3.14), "number"));
|
||||
assert!(handler.check_value_type(&json!(3.15), "number"));
|
||||
assert!(handler.check_value_type(&json!(42), "number"));
|
||||
assert!(handler.check_value_type(&json!(3.14), "float"));
|
||||
assert!(handler.check_value_type(&json!(3.15), "float"));
|
||||
|
||||
// Test boolean
|
||||
assert!(handler.check_value_type(&json!(true), "boolean"));
|
||||
|
|
@ -1890,12 +2032,16 @@ mod hallucination_tests {
|
|||
let handler = ArchFunctionHandler::new(
|
||||
"test-model".to_string(),
|
||||
ArchFunctionConfig::default(),
|
||||
"http://localhost:8000".to_string()
|
||||
"http://localhost:8000".to_string(),
|
||||
);
|
||||
|
||||
// Test valid type - no conversion needed
|
||||
assert!(handler.validate_or_convert_parameter(&json!(42), "integer").unwrap());
|
||||
assert!(handler.validate_or_convert_parameter(&json!("hello"), "string").unwrap());
|
||||
assert!(handler
|
||||
.validate_or_convert_parameter(&json!(42), "integer")
|
||||
.unwrap());
|
||||
assert!(handler
|
||||
.validate_or_convert_parameter(&json!("hello"), "string")
|
||||
.unwrap());
|
||||
|
||||
// Test integer to float conversion (convert_data_type supports this)
|
||||
let result = handler.validate_or_convert_parameter(&json!(42), "float");
|
||||
|
|
@ -1910,8 +2056,12 @@ mod hallucination_tests {
|
|||
assert!(!result.unwrap());
|
||||
|
||||
// Test number accepting both int and float
|
||||
assert!(handler.validate_or_convert_parameter(&json!(42), "number").unwrap());
|
||||
assert!(handler.validate_or_convert_parameter(&json!(3.14), "number").unwrap());
|
||||
assert!(handler
|
||||
.validate_or_convert_parameter(&json!(42), "number")
|
||||
.unwrap());
|
||||
assert!(handler
|
||||
.validate_or_convert_parameter(&json!(3.15), "number")
|
||||
.unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ use crate::router::plano_orchestrator::OrchestratorService;
|
|||
/// 2. PipelineProcessor - executes the agent pipeline
|
||||
/// 3. ResponseHandler - handles response streaming
|
||||
#[cfg(test)]
|
||||
mod integration_tests {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use common::configuration::{Agent, AgentFilterChain, Listener};
|
||||
|
||||
|
|
@ -62,7 +62,10 @@ mod integration_tests {
|
|||
|
||||
let agent_pipeline = AgentFilterChain {
|
||||
id: "terminal-agent".to_string(),
|
||||
filter_chain: Some(vec!["filter-agent".to_string(), "terminal-agent".to_string()]),
|
||||
filter_chain: Some(vec![
|
||||
"filter-agent".to_string(),
|
||||
"terminal-agent".to_string(),
|
||||
]),
|
||||
description: Some("Test pipeline".to_string()),
|
||||
default: Some(true),
|
||||
};
|
||||
|
|
|
|||
|
|
@ -2,48 +2,48 @@ use serde::{Deserialize, Serialize};
|
|||
use std::collections::HashMap;
|
||||
|
||||
pub const JSON_RPC_VERSION: &str = "2.0";
|
||||
pub const TOOL_CALL_METHOD : &str = "tools/call";
|
||||
pub const TOOL_CALL_METHOD: &str = "tools/call";
|
||||
pub const MCP_INITIALIZE: &str = "initialize";
|
||||
pub const MCP_INITIALIZE_NOTIFICATION: &str = "notifications/initialized";
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum JsonRpcId {
|
||||
String(String),
|
||||
Number(u64),
|
||||
String(String),
|
||||
Number(u64),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JsonRpcRequest {
|
||||
pub jsonrpc: String,
|
||||
pub id: JsonRpcId,
|
||||
pub method: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub params: Option<HashMap<String, serde_json::Value>>,
|
||||
pub jsonrpc: String,
|
||||
pub id: JsonRpcId,
|
||||
pub method: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub params: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JsonRpcNotification {
|
||||
pub jsonrpc: String,
|
||||
pub method: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub params: Option<HashMap<String, serde_json::Value>>,
|
||||
pub jsonrpc: String,
|
||||
pub method: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub params: Option<HashMap<String, serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JsonRpcError {
|
||||
pub code: i32,
|
||||
pub message: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub data: Option<serde_json::Value>,
|
||||
pub code: i32,
|
||||
pub message: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub data: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JsonRpcResponse {
|
||||
pub jsonrpc: String,
|
||||
pub id: JsonRpcId,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub result: Option<HashMap<String, serde_json::Value>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<JsonRpcError>,
|
||||
pub jsonrpc: String,
|
||||
pub id: JsonRpcId,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub result: Option<HashMap<String, serde_json::Value>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<JsonRpcError>,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
use bytes::Bytes;
|
||||
use common::configuration::{LlmProvider, ModelAlias};
|
||||
use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER};
|
||||
use common::consts::{
|
||||
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
|
||||
};
|
||||
use common::traces::TraceCollector;
|
||||
use hermesllm::apis::openai_responses::InputParam;
|
||||
use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
||||
|
|
@ -14,13 +16,14 @@ use std::sync::Arc;
|
|||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::router::llm_router::RouterService;
|
||||
use crate::handlers::utils::{create_streaming_response, ObservableStreamProcessor, truncate_message};
|
||||
use crate::handlers::router_chat::router_chat_get_upstream_model;
|
||||
use crate::handlers::utils::{
|
||||
create_streaming_response, truncate_message, ObservableStreamProcessor,
|
||||
};
|
||||
use crate::router::llm_router::RouterService;
|
||||
use crate::state::response_state_processor::ResponsesStateProcessor;
|
||||
use crate::state::{
|
||||
StateStorage, StateStorageError,
|
||||
extract_input_items, retrieve_and_combine_input
|
||||
extract_input_items, retrieve_and_combine_input, StateStorage, StateStorageError,
|
||||
};
|
||||
use crate::tracing::operation_component;
|
||||
|
||||
|
|
@ -39,7 +42,6 @@ pub async fn llm_chat(
|
|||
trace_collector: Arc<TraceCollector>,
|
||||
state_storage: Option<Arc<dyn StateStorage>>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
|
||||
let request_path = request.uri().path().to_string();
|
||||
let request_headers = request.headers().clone();
|
||||
let request_id = request_headers
|
||||
|
|
@ -74,8 +76,14 @@ pub async fn llm_chat(
|
|||
)) {
|
||||
Ok(request) => request,
|
||||
Err(err) => {
|
||||
warn!("[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request as ProviderRequestType: {}", request_id, err);
|
||||
let err_msg = format!("[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request: {}", request_id, err);
|
||||
warn!(
|
||||
"[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request as ProviderRequestType: {}",
|
||||
request_id, err
|
||||
);
|
||||
let err_msg = format!(
|
||||
"[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request: {}",
|
||||
request_id, err
|
||||
);
|
||||
let mut bad_request = Response::new(full(err_msg));
|
||||
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
|
||||
return Ok(bad_request);
|
||||
|
|
@ -85,7 +93,10 @@ pub async fn llm_chat(
|
|||
// === v1/responses state management: Extract input items early ===
|
||||
let mut original_input_items = Vec::new();
|
||||
let client_api = SupportedAPIsFromClient::from_endpoint(request_path.as_str());
|
||||
let is_responses_api_client = matches!(client_api, Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_)));
|
||||
let is_responses_api_client = matches!(
|
||||
client_api,
|
||||
Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_))
|
||||
);
|
||||
|
||||
// Model alias resolution: update model field in client_request immediately
|
||||
// This ensures all downstream objects use the resolved model
|
||||
|
|
@ -96,20 +107,28 @@ pub async fn llm_chat(
|
|||
|
||||
// Extract tool names and user message preview for span attributes
|
||||
let tool_names = client_request.get_tool_names();
|
||||
let user_message_preview = client_request.get_recent_user_message()
|
||||
let user_message_preview = client_request
|
||||
.get_recent_user_message()
|
||||
.map(|msg| truncate_message(&msg, 50));
|
||||
|
||||
client_request.set_model(resolved_model.clone());
|
||||
if client_request.remove_metadata_key("archgw_preference_config") {
|
||||
debug!("[PLANO_REQ_ID:{}] Removed archgw_preference_config from metadata", request_id);
|
||||
debug!(
|
||||
"[PLANO_REQ_ID:{}] Removed archgw_preference_config from metadata",
|
||||
request_id
|
||||
);
|
||||
}
|
||||
|
||||
// === v1/responses state management: Determine upstream API and combine input if needed ===
|
||||
// Do this BEFORE routing since routing consumes the request
|
||||
// Only process state if state_storage is configured
|
||||
let mut should_manage_state = false;
|
||||
if is_responses_api_client && state_storage.is_some() {
|
||||
if let ProviderRequestType::ResponsesAPIRequest(ref mut responses_req) = client_request {
|
||||
if is_responses_api_client {
|
||||
if let (
|
||||
ProviderRequestType::ResponsesAPIRequest(ref mut responses_req),
|
||||
Some(ref state_store),
|
||||
) = (&mut client_request, &state_storage)
|
||||
{
|
||||
// Extract original input once
|
||||
original_input_items = extract_input_items(&responses_req.input);
|
||||
|
||||
|
|
@ -120,18 +139,22 @@ pub async fn llm_chat(
|
|||
&request_path,
|
||||
&resolved_model,
|
||||
is_streaming_request,
|
||||
).await;
|
||||
)
|
||||
.await;
|
||||
|
||||
let upstream_api = SupportedUpstreamAPIs::from_endpoint(&upstream_path);
|
||||
|
||||
// Only manage state if upstream is NOT OpenAIResponsesAPI (needs translation)
|
||||
should_manage_state = !matches!(upstream_api, Some(SupportedUpstreamAPIs::OpenAIResponsesAPI(_)));
|
||||
should_manage_state = !matches!(
|
||||
upstream_api,
|
||||
Some(SupportedUpstreamAPIs::OpenAIResponsesAPI(_))
|
||||
);
|
||||
|
||||
if should_manage_state {
|
||||
// Retrieve and combine conversation history if previous_response_id exists
|
||||
if let Some(ref prev_resp_id) = responses_req.previous_response_id {
|
||||
match retrieve_and_combine_input(
|
||||
state_storage.as_ref().unwrap().clone(),
|
||||
state_store.clone(),
|
||||
prev_resp_id,
|
||||
original_input_items, // Pass ownership instead of cloning
|
||||
)
|
||||
|
|
@ -166,7 +189,10 @@ pub async fn llm_chat(
|
|||
}
|
||||
}
|
||||
} else {
|
||||
debug!("[PLANO_REQ_ID:{}] | BRIGHT_STAFF | Upstream supports ResponsesAPI natively.", request_id);
|
||||
debug!(
|
||||
"[PLANO_REQ_ID:{}] | BRIGHT_STAFF | Upstream supports ResponsesAPI natively.",
|
||||
request_id
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -177,7 +203,7 @@ pub async fn llm_chat(
|
|||
// Determine routing using the dedicated router_chat module
|
||||
let routing_result = match router_chat_get_upstream_model(
|
||||
router_service,
|
||||
client_request, // Pass the original request - router_chat will convert it
|
||||
client_request, // Pass the original request - router_chat will convert it
|
||||
&request_headers,
|
||||
trace_collector.clone(),
|
||||
&traceparent,
|
||||
|
|
@ -257,7 +283,8 @@ pub async fn llm_chat(
|
|||
user_message_preview,
|
||||
temperature,
|
||||
&llm_providers,
|
||||
).await;
|
||||
)
|
||||
.await;
|
||||
|
||||
// Create base processor for metrics and tracing
|
||||
let base_processor = ObservableStreamProcessor::new(
|
||||
|
|
@ -269,7 +296,11 @@ pub async fn llm_chat(
|
|||
|
||||
// === v1/responses state management: Wrap with ResponsesStateProcessor ===
|
||||
// Only wrap if we need to manage state (client is ResponsesAPI AND upstream is NOT ResponsesAPI AND state_storage is configured)
|
||||
let streaming_response = if should_manage_state && !original_input_items.is_empty() && state_storage.is_some() {
|
||||
let streaming_response = if let (true, false, Some(state_store)) = (
|
||||
should_manage_state,
|
||||
original_input_items.is_empty(),
|
||||
state_storage,
|
||||
) {
|
||||
// Extract Content-Encoding header to handle decompression for state parsing
|
||||
let content_encoding = response_headers
|
||||
.get("content-encoding")
|
||||
|
|
@ -279,7 +310,7 @@ pub async fn llm_chat(
|
|||
// Wrap with state management processor to store state after response completes
|
||||
let state_processor = ResponsesStateProcessor::new(
|
||||
base_processor,
|
||||
state_storage.unwrap(),
|
||||
state_store,
|
||||
original_input_items,
|
||||
resolved_model.clone(),
|
||||
model_name.clone(),
|
||||
|
|
@ -324,6 +355,7 @@ fn resolve_model_alias(
|
|||
}
|
||||
|
||||
/// Builds the LLM span with all required and optional attributes.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn build_llm_span(
|
||||
traceparent: &str,
|
||||
request_path: &str,
|
||||
|
|
@ -337,8 +369,8 @@ async fn build_llm_span(
|
|||
temperature: Option<f32>,
|
||||
llm_providers: &Arc<RwLock<Vec<LlmProvider>>>,
|
||||
) -> common::traces::Span {
|
||||
use common::traces::{SpanBuilder, SpanKind, parse_traceparent};
|
||||
use crate::tracing::{http, llm, OperationNameBuilder};
|
||||
use common::traces::{parse_traceparent, SpanBuilder, SpanKind};
|
||||
|
||||
// Calculate the upstream path based on provider configuration
|
||||
let upstream_path = get_upstream_path(
|
||||
|
|
@ -347,13 +379,14 @@ async fn build_llm_span(
|
|||
request_path,
|
||||
resolved_model,
|
||||
is_streaming,
|
||||
).await;
|
||||
)
|
||||
.await;
|
||||
|
||||
// Build operation name showing path transformation if different
|
||||
let operation_name = if request_path != upstream_path {
|
||||
OperationNameBuilder::new()
|
||||
.with_method("POST")
|
||||
.with_path(&format!("{} >> {}", request_path, upstream_path))
|
||||
.with_path(format!("{} >> {}", request_path, upstream_path))
|
||||
.with_target(resolved_model)
|
||||
.build()
|
||||
} else {
|
||||
|
|
@ -388,7 +421,8 @@ async fn build_llm_span(
|
|||
}
|
||||
|
||||
if let Some(tools) = tool_names {
|
||||
let formatted_tools = tools.iter()
|
||||
let formatted_tools = tools
|
||||
.iter()
|
||||
.map(|name| format!("{}(...)", name))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
|
@ -436,8 +470,7 @@ async fn get_provider_info(
|
|||
|
||||
// First, try to find by model name or provider name
|
||||
let provider = providers_lock.iter().find(|p| {
|
||||
p.model.as_ref().map(|m| m == model_name).unwrap_or(false)
|
||||
|| p.name == model_name
|
||||
p.model.as_ref().map(|m| m == model_name).unwrap_or(false) || p.name == model_name
|
||||
});
|
||||
|
||||
if let Some(provider) = provider {
|
||||
|
|
@ -446,9 +479,7 @@ async fn get_provider_info(
|
|||
return (provider_id, prefix);
|
||||
}
|
||||
|
||||
let default_provider = providers_lock.iter().find(|p| {
|
||||
p.default.unwrap_or(false)
|
||||
});
|
||||
let default_provider = providers_lock.iter().find(|p| p.default.unwrap_or(false));
|
||||
|
||||
if let Some(provider) = default_provider {
|
||||
let provider_id = provider.provider_interface.to_provider_id();
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
pub mod agent_chat_completions;
|
||||
pub mod agent_selector;
|
||||
pub mod llm;
|
||||
pub mod router_chat;
|
||||
pub mod models;
|
||||
pub mod function_calling;
|
||||
pub mod jsonrpc;
|
||||
pub mod llm;
|
||||
pub mod models;
|
||||
pub mod pipeline_processor;
|
||||
pub mod response_handler;
|
||||
pub mod router_chat;
|
||||
pub mod utils;
|
||||
pub mod jsonrpc;
|
||||
|
||||
#[cfg(test)]
|
||||
mod integration_tests;
|
||||
|
|
|
|||
|
|
@ -82,6 +82,7 @@ impl PipelineProcessor {
|
|||
}
|
||||
|
||||
/// Record a span for filter execution
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn record_filter_span(
|
||||
&self,
|
||||
collector: &std::sync::Arc<common::traces::TraceCollector>,
|
||||
|
|
@ -132,6 +133,7 @@ impl PipelineProcessor {
|
|||
}
|
||||
|
||||
/// Record a span for MCP protocol interactions
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn record_agent_filter_span(
|
||||
&self,
|
||||
collector: &std::sync::Arc<common::traces::TraceCollector>,
|
||||
|
|
@ -156,12 +158,12 @@ impl PipelineProcessor {
|
|||
.build();
|
||||
|
||||
let mut span_builder = SpanBuilder::new(&operation_name)
|
||||
.with_span_id(span_id.unwrap_or_else(|| generate_random_span_id()))
|
||||
.with_span_id(span_id.unwrap_or_else(generate_random_span_id))
|
||||
.with_kind(SpanKind::Client)
|
||||
.with_start_time(start_time)
|
||||
.with_end_time(end_time)
|
||||
.with_attribute(http::METHOD, "POST")
|
||||
.with_attribute(http::TARGET, &format!("/mcp ({})", operation.to_string()))
|
||||
.with_attribute(http::TARGET, format!("/mcp ({})", operation))
|
||||
.with_attribute("mcp.operation", operation.to_string())
|
||||
.with_attribute("mcp.agent_id", agent_id.to_string())
|
||||
.with_attribute(
|
||||
|
|
@ -188,6 +190,7 @@ impl PipelineProcessor {
|
|||
}
|
||||
|
||||
/// Process the filter chain of agents (all except the terminal agent)
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn process_filter_chain(
|
||||
&mut self,
|
||||
chat_history: &[Message],
|
||||
|
|
@ -1023,7 +1026,7 @@ mod tests {
|
|||
}
|
||||
});
|
||||
|
||||
let sse_body = format!("event: message\ndata: {}\n\n", rpc_body.to_string());
|
||||
let sse_body = format!("event: message\ndata: {}\n\n", rpc_body);
|
||||
|
||||
let mut server = Server::new_async().await;
|
||||
let _m = server
|
||||
|
|
@ -1061,10 +1064,10 @@ mod tests {
|
|||
.await;
|
||||
|
||||
match result {
|
||||
Err(PipelineError::ClientError { status, body, .. }) => {
|
||||
assert_eq!(status, 400);
|
||||
assert_eq!(body, "bad tool call");
|
||||
}
|
||||
Err(PipelineError::ClientError { status, body, .. }) => {
|
||||
assert_eq!(status, 400);
|
||||
assert_eq!(body, "bad tool call");
|
||||
}
|
||||
_ => panic!("Expected client error when isError flag is set"),
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -133,9 +133,7 @@ impl ResponseHandler {
|
|||
let response_headers = llm_response.headers();
|
||||
let is_sse_streaming = response_headers
|
||||
.get(hyper::header::CONTENT_TYPE)
|
||||
.map_or(false, |v| {
|
||||
v.to_str().unwrap_or("").contains("text/event-stream")
|
||||
});
|
||||
.is_some_and(|v| v.to_str().unwrap_or("").contains("text/event-stream"));
|
||||
|
||||
let response_bytes = llm_response
|
||||
.bytes()
|
||||
|
|
@ -164,7 +162,7 @@ impl ResponseHandler {
|
|||
match transformed_event.provider_response() {
|
||||
Ok(provider_response) => {
|
||||
if let Some(content) = provider_response.content_delta() {
|
||||
accumulated_text.push_str(&content);
|
||||
accumulated_text.push_str(content);
|
||||
} else {
|
||||
info!("No content delta in provider response");
|
||||
}
|
||||
|
|
@ -174,7 +172,7 @@ impl ResponseHandler {
|
|||
}
|
||||
}
|
||||
}
|
||||
return Ok(accumulated_text);
|
||||
Ok(accumulated_text)
|
||||
} else {
|
||||
// If not SSE, treat as regular text response
|
||||
let response_text = String::from_utf8(response_bytes.to_vec()).map_err(|e| {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
use common::configuration::ModelUsagePreference;
|
||||
use common::consts::{REQUEST_ID_HEADER};
|
||||
use common::traces::{TraceCollector, SpanKind, SpanBuilder, parse_traceparent};
|
||||
use common::consts::REQUEST_ID_HEADER;
|
||||
use common::traces::{parse_traceparent, SpanBuilder, SpanKind, TraceCollector};
|
||||
use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
|
||||
use hermesllm::{ProviderRequest, ProviderRequestType};
|
||||
use hyper::StatusCode;
|
||||
|
|
@ -9,10 +9,10 @@ use std::sync::Arc;
|
|||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::router::llm_router::RouterService;
|
||||
use crate::tracing::{OperationNameBuilder, operation_component, http, routing};
|
||||
use crate::tracing::{http, operation_component, routing, OperationNameBuilder};
|
||||
|
||||
pub struct RoutingResult {
|
||||
pub model_name: String
|
||||
pub model_name: String,
|
||||
}
|
||||
|
||||
pub struct RoutingError {
|
||||
|
|
@ -24,7 +24,7 @@ impl RoutingError {
|
|||
pub fn internal_error(message: String) -> Self {
|
||||
Self {
|
||||
message,
|
||||
status_code: StatusCode::INTERNAL_SERVER_ERROR
|
||||
status_code: StatusCode::INTERNAL_SERVER_ERROR,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -52,9 +52,7 @@ pub async fn router_chat_get_upstream_model(
|
|||
// Convert to ChatCompletionsRequest for routing (regardless of input type)
|
||||
let chat_request = match ProviderRequestType::try_from((
|
||||
client_request,
|
||||
&SupportedUpstreamAPIs::OpenAIChatCompletions(
|
||||
hermesllm::apis::OpenAIApi::ChatCompletions,
|
||||
),
|
||||
&SupportedUpstreamAPIs::OpenAIChatCompletions(hermesllm::apis::OpenAIApi::ChatCompletions),
|
||||
)) {
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(req)) => req,
|
||||
Ok(
|
||||
|
|
@ -69,7 +67,10 @@ pub async fn router_chat_get_upstream_model(
|
|||
));
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("Failed to convert request to ChatCompletionsRequest: {}", err);
|
||||
warn!(
|
||||
"Failed to convert request to ChatCompletionsRequest: {}",
|
||||
err
|
||||
);
|
||||
return Err(RoutingError::internal_error(format!(
|
||||
"Failed to convert request: {}",
|
||||
err
|
||||
|
|
@ -151,9 +152,7 @@ pub async fn router_chat_get_upstream_model(
|
|||
)
|
||||
.await;
|
||||
|
||||
Ok(RoutingResult {
|
||||
model_name
|
||||
})
|
||||
Ok(RoutingResult { model_name })
|
||||
}
|
||||
None => {
|
||||
// No route determined, use default model from request
|
||||
|
|
@ -176,7 +175,7 @@ pub async fn router_chat_get_upstream_model(
|
|||
.await;
|
||||
|
||||
Ok(RoutingResult {
|
||||
model_name: default_model
|
||||
model_name: default_model,
|
||||
})
|
||||
}
|
||||
},
|
||||
|
|
@ -194,9 +193,10 @@ pub async fn router_chat_get_upstream_model(
|
|||
)
|
||||
.await;
|
||||
|
||||
Err(RoutingError::internal_error(
|
||||
format!("Failed to determine route: {}", err)
|
||||
))
|
||||
Err(RoutingError::internal_error(format!(
|
||||
"Failed to determine route: {}",
|
||||
err
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -230,7 +230,10 @@ async fn record_routing_span(
|
|||
.with_end_time(std::time::SystemTime::now())
|
||||
.with_attribute(http::METHOD, "POST")
|
||||
.with_attribute(http::TARGET, routing_api_path.to_string())
|
||||
.with_attribute(routing::ROUTE_DETERMINATION_MS, start_time.elapsed().as_millis().to_string());
|
||||
.with_attribute(
|
||||
routing::ROUTE_DETERMINATION_MS,
|
||||
start_time.elapsed().as_millis().to_string(),
|
||||
);
|
||||
|
||||
// Only set parent span ID if it exists (not a root span)
|
||||
if let Some(parent) = parent_span_id {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use bytes::Bytes;
|
||||
use common::traces::{Span, Attribute, AttributeValue, TraceCollector, Event};
|
||||
use common::traces::{Attribute, AttributeValue, Event, Span, TraceCollector};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::StreamBody;
|
||||
use hyper::body::Frame;
|
||||
|
|
@ -11,7 +11,7 @@ use tokio_stream::StreamExt;
|
|||
use tracing::warn;
|
||||
|
||||
// Import tracing constants
|
||||
use crate::tracing::{llm, error};
|
||||
use crate::tracing::{error, llm};
|
||||
|
||||
/// Trait for processing streaming chunks
|
||||
/// Implementors can inject custom logic during streaming (e.g., hallucination detection, logging)
|
||||
|
|
@ -97,7 +97,6 @@ impl StreamProcessor for ObservableStreamProcessor {
|
|||
},
|
||||
});
|
||||
|
||||
|
||||
self.span.attributes.push(Attribute {
|
||||
key: llm::DURATION_MS.to_string(),
|
||||
value: AttributeValue {
|
||||
|
|
@ -119,11 +118,9 @@ impl StreamProcessor for ObservableStreamProcessor {
|
|||
if let Ok(start_time_nanos) = self.span.start_time_unix_nano.parse::<u128>() {
|
||||
// Convert ttft from milliseconds to nanoseconds and add to start time
|
||||
let event_timestamp = start_time_nanos + (ttft * 1_000_000);
|
||||
let mut event = Event::new(llm::TIME_TO_FIRST_TOKEN_MS.to_string(), event_timestamp);
|
||||
event.add_attribute(
|
||||
llm::TIME_TO_FIRST_TOKEN_MS.to_string(),
|
||||
ttft.to_string(),
|
||||
);
|
||||
let mut event =
|
||||
Event::new(llm::TIME_TO_FIRST_TOKEN_MS.to_string(), event_timestamp);
|
||||
event.add_attribute(llm::TIME_TO_FIRST_TOKEN_MS.to_string(), ttft.to_string());
|
||||
|
||||
// Initialize events vector if needed
|
||||
if self.span.events.is_none() {
|
||||
|
|
@ -137,7 +134,8 @@ impl StreamProcessor for ObservableStreamProcessor {
|
|||
}
|
||||
|
||||
// Record the finalized span
|
||||
self.collector.record_span(&self.service_name, self.span.clone());
|
||||
self.collector
|
||||
.record_span(&self.service_name, self.span.clone());
|
||||
}
|
||||
|
||||
fn on_error(&mut self, error_msg: &str) {
|
||||
|
|
@ -173,7 +171,8 @@ impl StreamProcessor for ObservableStreamProcessor {
|
|||
});
|
||||
|
||||
// Record the error span
|
||||
self.collector.record_span(&self.service_name, self.span.clone());
|
||||
self.collector
|
||||
.record_span(&self.service_name, self.span.clone());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -4,13 +4,15 @@ use brightstaff::handlers::llm::llm_chat;
|
|||
use brightstaff::handlers::models::list_models;
|
||||
use brightstaff::router::llm_router::RouterService;
|
||||
use brightstaff::router::plano_orchestrator::OrchestratorService;
|
||||
use brightstaff::state::StateStorage;
|
||||
use brightstaff::state::postgresql::PostgreSQLConversationStorage;
|
||||
use brightstaff::state::memory::MemoryConversationalStorage;
|
||||
use brightstaff::state::postgresql::PostgreSQLConversationStorage;
|
||||
use brightstaff::state::StateStorage;
|
||||
use brightstaff::utils::tracing::init_tracer;
|
||||
use bytes::Bytes;
|
||||
use common::configuration::{Agent, Configuration};
|
||||
use common::consts::{CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH, PLANO_ORCHESTRATOR_MODEL_NAME};
|
||||
use common::consts::{
|
||||
CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH, PLANO_ORCHESTRATOR_MODEL_NAME,
|
||||
};
|
||||
use common::traces::TraceCollector;
|
||||
use http_body_util::{combinators::BoxBody, BodyExt, Empty};
|
||||
use hyper::body::Incoming;
|
||||
|
|
@ -105,7 +107,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
PLANO_ORCHESTRATOR_MODEL_NAME.to_string(),
|
||||
));
|
||||
|
||||
|
||||
let model_aliases = Arc::new(arch_config.model_aliases.clone());
|
||||
|
||||
// Initialize trace collector and start background flusher
|
||||
|
|
@ -127,33 +128,33 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
// Configurable via arch_config.yaml state_storage section
|
||||
// If not configured, state management is disabled
|
||||
// Environment variables are substituted by envsubst before config is read
|
||||
let state_storage: Option<Arc<dyn StateStorage>> = if let Some(storage_config) = &arch_config.state_storage {
|
||||
let storage: Arc<dyn StateStorage> = match storage_config.storage_type {
|
||||
common::configuration::StateStorageType::Memory => {
|
||||
info!("Initialized conversation state storage: Memory");
|
||||
Arc::new(MemoryConversationalStorage::new())
|
||||
}
|
||||
common::configuration::StateStorageType::Postgres => {
|
||||
let connection_string = storage_config
|
||||
.connection_string
|
||||
.as_ref()
|
||||
.expect("connection_string is required for postgres state_storage");
|
||||
let state_storage: Option<Arc<dyn StateStorage>> =
|
||||
if let Some(storage_config) = &arch_config.state_storage {
|
||||
let storage: Arc<dyn StateStorage> = match storage_config.storage_type {
|
||||
common::configuration::StateStorageType::Memory => {
|
||||
info!("Initialized conversation state storage: Memory");
|
||||
Arc::new(MemoryConversationalStorage::new())
|
||||
}
|
||||
common::configuration::StateStorageType::Postgres => {
|
||||
let connection_string = storage_config
|
||||
.connection_string
|
||||
.as_ref()
|
||||
.expect("connection_string is required for postgres state_storage");
|
||||
|
||||
debug!("Postgres connection string (full): {}", connection_string);
|
||||
info!("Initializing conversation state storage: Postgres");
|
||||
Arc::new(
|
||||
PostgreSQLConversationStorage::new(connection_string.clone())
|
||||
.await
|
||||
.expect("Failed to initialize Postgres state storage"),
|
||||
)
|
||||
}
|
||||
debug!("Postgres connection string (full): {}", connection_string);
|
||||
info!("Initializing conversation state storage: Postgres");
|
||||
Arc::new(
|
||||
PostgreSQLConversationStorage::new(connection_string.clone())
|
||||
.await
|
||||
.expect("Failed to initialize Postgres state storage"),
|
||||
)
|
||||
}
|
||||
};
|
||||
Some(storage)
|
||||
} else {
|
||||
info!("No state_storage configured - conversation state management disabled");
|
||||
None
|
||||
};
|
||||
Some(storage)
|
||||
} else {
|
||||
info!("No state_storage configured - conversation state management disabled");
|
||||
None
|
||||
};
|
||||
|
||||
|
||||
loop {
|
||||
let (stream, _) = listener.accept().await?;
|
||||
|
|
@ -208,12 +209,22 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
}
|
||||
}
|
||||
match (req.method(), path) {
|
||||
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH) => {
|
||||
let fully_qualified_url =
|
||||
format!("{}{}", llm_provider_url, path);
|
||||
llm_chat(req, router_service, fully_qualified_url, model_aliases, llm_providers, trace_collector, state_storage)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
(
|
||||
&Method::POST,
|
||||
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH,
|
||||
) => {
|
||||
let fully_qualified_url = format!("{}{}", llm_provider_url, path);
|
||||
llm_chat(
|
||||
req,
|
||||
router_service,
|
||||
fully_qualified_url,
|
||||
model_aliases,
|
||||
llm_providers,
|
||||
trace_collector,
|
||||
state_storage,
|
||||
)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
}
|
||||
(&Method::POST, "/function_calling") => {
|
||||
let fully_qualified_url =
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ use std::collections::HashMap;
|
|||
|
||||
use common::configuration::{AgentUsagePreference, OrchestrationPreference};
|
||||
use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role};
|
||||
use serde::{Deserialize, Serialize, ser::Serialize as SerializeTrait};
|
||||
use serde::{ser::Serialize as SerializeTrait, Deserialize, Serialize};
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use super::orchestrator_model::{OrchestratorModel, OrchestratorModelError};
|
||||
|
|
@ -144,7 +144,7 @@ impl OrchestratorModelV1 {
|
|||
// Format routes: each route as JSON on its own line with standard spacing
|
||||
let agent_orchestration_json_str = agent_orchestration_values
|
||||
.iter()
|
||||
.map(|pref| to_spaced_json(pref))
|
||||
.map(to_spaced_json)
|
||||
.collect::<Vec<String>>()
|
||||
.join("\n");
|
||||
let agent_orchestration_to_model_map: HashMap<String, String> = agent_orchestrations
|
||||
|
|
@ -238,24 +238,26 @@ impl OrchestratorModel for OrchestratorModelV1 {
|
|||
let selected_conversation_list = selected_messages_list_reversed
|
||||
.iter()
|
||||
.rev()
|
||||
.map(|message| {
|
||||
Message {
|
||||
role: message.role.clone(),
|
||||
content: MessageContent::Text(message.content.to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}
|
||||
.map(|message| Message {
|
||||
role: message.role.clone(),
|
||||
content: MessageContent::Text(message.content.to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
})
|
||||
.collect::<Vec<Message>>();
|
||||
|
||||
// Generate the orchestrator request message based on the usage preferences.
|
||||
// If preferences are passed in request then we use them;
|
||||
// Otherwise, we use the default orchestration modelpreferences.
|
||||
let orchestrator_message = match convert_to_orchestrator_preferences(usage_preferences_from_request) {
|
||||
Some(prefs) => generate_orchestrator_message(&prefs, &selected_conversation_list),
|
||||
None => generate_orchestrator_message(&self.agent_orchestration_json_str, &selected_conversation_list),
|
||||
};
|
||||
let orchestrator_message =
|
||||
match convert_to_orchestrator_preferences(usage_preferences_from_request) {
|
||||
Some(prefs) => generate_orchestrator_message(&prefs, &selected_conversation_list),
|
||||
None => generate_orchestrator_message(
|
||||
&self.agent_orchestration_json_str,
|
||||
&selected_conversation_list,
|
||||
),
|
||||
};
|
||||
|
||||
ChatCompletionsRequest {
|
||||
model: self.orchestration_model.clone(),
|
||||
|
|
@ -280,7 +282,8 @@ impl OrchestratorModel for OrchestratorModelV1 {
|
|||
return Ok(None);
|
||||
}
|
||||
let orchestrator_resp_fixed = fix_json_response(content);
|
||||
let orchestrator_response: AgentOrchestratorResponse = serde_json::from_str(orchestrator_resp_fixed.as_str())?;
|
||||
let orchestrator_response: AgentOrchestratorResponse =
|
||||
serde_json::from_str(orchestrator_resp_fixed.as_str())?;
|
||||
|
||||
let selected_routes = orchestrator_response.route.unwrap_or_default();
|
||||
|
||||
|
|
@ -320,7 +323,11 @@ impl OrchestratorModel for OrchestratorModelV1 {
|
|||
} else {
|
||||
// If no usage preferences are passed in request then use the default orchestration model preferences
|
||||
for selected_route in valid_routes {
|
||||
if let Some(model) = self.agent_orchestration_to_model_map.get(&selected_route).cloned() {
|
||||
if let Some(model) = self
|
||||
.agent_orchestration_to_model_map
|
||||
.get(&selected_route)
|
||||
.cloned()
|
||||
{
|
||||
result.push((selected_route, model));
|
||||
} else {
|
||||
warn!(
|
||||
|
|
@ -375,7 +382,7 @@ fn convert_to_orchestrator_preferences(
|
|||
// Format routes: each route as JSON on its own line with standard spacing
|
||||
let routes_str = orchestration_preferences
|
||||
.iter()
|
||||
.map(|pref| to_spaced_json(pref))
|
||||
.map(to_spaced_json)
|
||||
.collect::<Vec<String>>()
|
||||
.join("\n");
|
||||
|
||||
|
|
@ -425,7 +432,10 @@ mod tests {
|
|||
// CRITICAL: Test that colons inside string values are NOT modified
|
||||
let with_colon = serde_json::json!({"name": "foo:bar", "url": "http://example.com"});
|
||||
let result = to_spaced_json(&with_colon);
|
||||
assert_eq!(result, r#"{"name": "foo:bar", "url": "http://example.com"}"#);
|
||||
assert_eq!(
|
||||
result,
|
||||
r#"{"name": "foo:bar", "url": "http://example.com"}"#
|
||||
);
|
||||
|
||||
// Test empty object and array
|
||||
let empty_obj = serde_json::json!({});
|
||||
|
|
@ -446,7 +456,8 @@ mod tests {
|
|||
});
|
||||
let result = to_spaced_json(&complex);
|
||||
// Verify URLs with colons are preserved correctly
|
||||
assert!(result.contains(r#""urls": ["https://api.example.com:8080/path", "file:///local/path"]"#));
|
||||
assert!(result
|
||||
.contains(r#""urls": ["https://api.example.com:8080/path", "file:///local/path"]"#));
|
||||
// Verify spacing format
|
||||
assert!(result.contains(r#""type": "object""#));
|
||||
assert!(result.contains(r#""properties": {}"#));
|
||||
|
|
@ -497,10 +508,16 @@ If no routes are needed, return an empty list for `route`.
|
|||
]
|
||||
}
|
||||
"#;
|
||||
let agent_orchestrations =
|
||||
serde_json::from_str::<HashMap<String, Vec<OrchestrationPreference>>>(orchestrations_str).unwrap();
|
||||
let agent_orchestrations = serde_json::from_str::<
|
||||
HashMap<String, Vec<OrchestrationPreference>>,
|
||||
>(orchestrations_str)
|
||||
.unwrap();
|
||||
let orchestration_model = "test-model".to_string();
|
||||
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), usize::MAX);
|
||||
let orchestrator = OrchestratorModelV1::new(
|
||||
agent_orchestrations,
|
||||
orchestration_model.clone(),
|
||||
usize::MAX,
|
||||
);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
|
|
@ -568,7 +585,11 @@ If no routes are needed, return an empty list for `route`.
|
|||
// Empty orchestrations map - not used when usage_preferences are provided
|
||||
let agent_orchestrations: HashMap<String, Vec<OrchestrationPreference>> = HashMap::new();
|
||||
let orchestration_model = "test-model".to_string();
|
||||
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), usize::MAX);
|
||||
let orchestrator = OrchestratorModelV1::new(
|
||||
agent_orchestrations,
|
||||
orchestration_model.clone(),
|
||||
usize::MAX,
|
||||
);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
|
|
@ -640,10 +661,13 @@ If no routes are needed, return an empty list for `route`.
|
|||
]
|
||||
}
|
||||
"#;
|
||||
let agent_orchestrations =
|
||||
serde_json::from_str::<HashMap<String, Vec<OrchestrationPreference>>>(orchestrations_str).unwrap();
|
||||
let agent_orchestrations = serde_json::from_str::<
|
||||
HashMap<String, Vec<OrchestrationPreference>>,
|
||||
>(orchestrations_str)
|
||||
.unwrap();
|
||||
let orchestration_model = "test-model".to_string();
|
||||
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), 235);
|
||||
let orchestrator =
|
||||
OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), 235);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
|
|
@ -709,11 +733,14 @@ If no routes are needed, return an empty list for `route`.
|
|||
]
|
||||
}
|
||||
"#;
|
||||
let agent_orchestrations =
|
||||
serde_json::from_str::<HashMap<String, Vec<OrchestrationPreference>>>(orchestrations_str).unwrap();
|
||||
let agent_orchestrations = serde_json::from_str::<
|
||||
HashMap<String, Vec<OrchestrationPreference>>,
|
||||
>(orchestrations_str)
|
||||
.unwrap();
|
||||
|
||||
let orchestration_model = "test-model".to_string();
|
||||
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), 200);
|
||||
let orchestrator =
|
||||
OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), 200);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
|
|
@ -787,10 +814,13 @@ If no routes are needed, return an empty list for `route`.
|
|||
]
|
||||
}
|
||||
"#;
|
||||
let agent_orchestrations =
|
||||
serde_json::from_str::<HashMap<String, Vec<OrchestrationPreference>>>(orchestrations_str).unwrap();
|
||||
let agent_orchestrations = serde_json::from_str::<
|
||||
HashMap<String, Vec<OrchestrationPreference>>,
|
||||
>(orchestrations_str)
|
||||
.unwrap();
|
||||
let orchestration_model = "test-model".to_string();
|
||||
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), 230);
|
||||
let orchestrator =
|
||||
OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), 230);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
|
|
@ -871,10 +901,16 @@ If no routes are needed, return an empty list for `route`.
|
|||
]
|
||||
}
|
||||
"#;
|
||||
let agent_orchestrations =
|
||||
serde_json::from_str::<HashMap<String, Vec<OrchestrationPreference>>>(orchestrations_str).unwrap();
|
||||
let agent_orchestrations = serde_json::from_str::<
|
||||
HashMap<String, Vec<OrchestrationPreference>>,
|
||||
>(orchestrations_str)
|
||||
.unwrap();
|
||||
let orchestration_model = "test-model".to_string();
|
||||
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), usize::MAX);
|
||||
let orchestrator = OrchestratorModelV1::new(
|
||||
agent_orchestrations,
|
||||
orchestration_model.clone(),
|
||||
usize::MAX,
|
||||
);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
|
|
@ -957,10 +993,16 @@ If no routes are needed, return an empty list for `route`.
|
|||
]
|
||||
}
|
||||
"#;
|
||||
let agent_orchestrations =
|
||||
serde_json::from_str::<HashMap<String, Vec<OrchestrationPreference>>>(orchestrations_str).unwrap();
|
||||
let agent_orchestrations = serde_json::from_str::<
|
||||
HashMap<String, Vec<OrchestrationPreference>>,
|
||||
>(orchestrations_str)
|
||||
.unwrap();
|
||||
let orchestration_model = "test-model".to_string();
|
||||
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model.clone(), usize::MAX);
|
||||
let orchestrator = OrchestratorModelV1::new(
|
||||
agent_orchestrations,
|
||||
orchestration_model.clone(),
|
||||
usize::MAX,
|
||||
);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
|
|
@ -1034,10 +1076,13 @@ If no routes are needed, return an empty list for `route`.
|
|||
]
|
||||
}
|
||||
"#;
|
||||
let agent_orchestrations =
|
||||
serde_json::from_str::<HashMap<String, Vec<OrchestrationPreference>>>(orchestrations_str).unwrap();
|
||||
let agent_orchestrations = serde_json::from_str::<
|
||||
HashMap<String, Vec<OrchestrationPreference>>,
|
||||
>(orchestrations_str)
|
||||
.unwrap();
|
||||
|
||||
let orchestrator = OrchestratorModelV1::new(agent_orchestrations, "test-model".to_string(), 2000);
|
||||
let orchestrator =
|
||||
OrchestratorModelV1::new(agent_orchestrations, "test-model".to_string(), 2000);
|
||||
|
||||
// Case 1: Valid JSON with single route in array
|
||||
let input = r#"{"route": ["Image generation"]}"#;
|
||||
|
|
|
|||
|
|
@ -34,10 +34,7 @@ pub enum OrchestrationError {
|
|||
pub type Result<T> = std::result::Result<T, OrchestrationError>;
|
||||
|
||||
impl OrchestratorService {
|
||||
pub fn new(
|
||||
orchestrator_url: String,
|
||||
orchestration_model_name: String,
|
||||
) -> Self {
|
||||
pub fn new(orchestrator_url: String, orchestration_model_name: String) -> Self {
|
||||
// Empty agent orchestrations - will be provided via usage_preferences in requests
|
||||
let agent_orchestrations: HashMap<String, Vec<OrchestrationPreference>> = HashMap::new();
|
||||
|
||||
|
|
|
|||
|
|
@ -85,13 +85,19 @@ impl StateStorage for MemoryConversationalStorage {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use hermesllm::apis::openai_responses::{InputItem, InputMessage, MessageRole, InputContent, MessageContent};
|
||||
use hermesllm::apis::openai_responses::{
|
||||
InputContent, InputItem, InputMessage, MessageContent, MessageRole,
|
||||
};
|
||||
|
||||
fn create_test_state(response_id: &str, num_messages: usize) -> OpenAIConversationState {
|
||||
let mut input_items = Vec::new();
|
||||
for i in 0..num_messages {
|
||||
input_items.push(InputItem::Message(InputMessage {
|
||||
role: if i % 2 == 0 { MessageRole::User } else { MessageRole::Assistant },
|
||||
role: if i % 2 == 0 {
|
||||
MessageRole::User
|
||||
} else {
|
||||
MessageRole::Assistant
|
||||
},
|
||||
content: MessageContent::Items(vec![InputContent::InputText {
|
||||
text: format!("Message {}", i),
|
||||
}]),
|
||||
|
|
@ -252,7 +258,9 @@ mod tests {
|
|||
let merged = storage.merge(&prev_state, current_input);
|
||||
|
||||
// Verify order: prev messages first, then current
|
||||
let InputItem::Message(msg) = &merged[0] else { panic!("Expected Message") };
|
||||
let InputItem::Message(msg) = &merged[0] else {
|
||||
panic!("Expected Message")
|
||||
};
|
||||
match &msg.content {
|
||||
MessageContent::Items(items) => match &items[0] {
|
||||
InputContent::InputText { text } => assert_eq!(text, "Message 0"),
|
||||
|
|
@ -261,7 +269,9 @@ mod tests {
|
|||
_ => panic!("Expected MessageContent::Items"),
|
||||
}
|
||||
|
||||
let InputItem::Message(msg) = &merged[2] else { panic!("Expected Message") };
|
||||
let InputItem::Message(msg) = &merged[2] else {
|
||||
panic!("Expected Message")
|
||||
};
|
||||
match &msg.content {
|
||||
MessageContent::Items(items) => match &items[0] {
|
||||
InputContent::InputText { text } => assert_eq!(text, "Message 2"),
|
||||
|
|
@ -404,7 +414,8 @@ mod tests {
|
|||
let current_input = vec![InputItem::Message(InputMessage {
|
||||
role: MessageRole::User,
|
||||
content: MessageContent::Items(vec![InputContent::InputText {
|
||||
text: "Function result: {\"temperature\": 72, \"condition\": \"sunny\"}".to_string(),
|
||||
text: "Function result: {\"temperature\": 72, \"condition\": \"sunny\"}"
|
||||
.to_string(),
|
||||
}]),
|
||||
})];
|
||||
|
||||
|
|
@ -415,7 +426,9 @@ mod tests {
|
|||
assert_eq!(merged.len(), 3);
|
||||
|
||||
// Verify the order and content
|
||||
let InputItem::Message(msg1) = &merged[0] else { panic!("Expected Message") };
|
||||
let InputItem::Message(msg1) = &merged[0] else {
|
||||
panic!("Expected Message")
|
||||
};
|
||||
assert!(matches!(msg1.role, MessageRole::User));
|
||||
match &msg1.content {
|
||||
MessageContent::Items(items) => match &items[0] {
|
||||
|
|
@ -427,7 +440,9 @@ mod tests {
|
|||
_ => panic!("Expected MessageContent::Items"),
|
||||
}
|
||||
|
||||
let InputItem::Message(msg2) = &merged[1] else { panic!("Expected Message") };
|
||||
let InputItem::Message(msg2) = &merged[1] else {
|
||||
panic!("Expected Message")
|
||||
};
|
||||
assert!(matches!(msg2.role, MessageRole::Assistant));
|
||||
match &msg2.content {
|
||||
MessageContent::Items(items) => match &items[0] {
|
||||
|
|
@ -439,7 +454,9 @@ mod tests {
|
|||
_ => panic!("Expected MessageContent::Items"),
|
||||
}
|
||||
|
||||
let InputItem::Message(msg3) = &merged[2] else { panic!("Expected Message") };
|
||||
let InputItem::Message(msg3) = &merged[2] else {
|
||||
panic!("Expected Message")
|
||||
};
|
||||
assert!(matches!(msg3.role, MessageRole::User));
|
||||
match &msg3.content {
|
||||
MessageContent::Items(items) => match &items[0] {
|
||||
|
|
@ -508,11 +525,15 @@ mod tests {
|
|||
assert_eq!(merged.len(), 5);
|
||||
|
||||
// Verify first item is original user message
|
||||
let InputItem::Message(first) = &merged[0] else { panic!("Expected Message") };
|
||||
let InputItem::Message(first) = &merged[0] else {
|
||||
panic!("Expected Message")
|
||||
};
|
||||
assert!(matches!(first.role, MessageRole::User));
|
||||
|
||||
// Verify last two are function outputs
|
||||
let InputItem::Message(second_last) = &merged[3] else { panic!("Expected Message") };
|
||||
let InputItem::Message(second_last) = &merged[3] else {
|
||||
panic!("Expected Message")
|
||||
};
|
||||
assert!(matches!(second_last.role, MessageRole::User));
|
||||
match &second_last.content {
|
||||
MessageContent::Items(items) => match &items[0] {
|
||||
|
|
@ -522,7 +543,9 @@ mod tests {
|
|||
_ => panic!("Expected MessageContent::Items"),
|
||||
}
|
||||
|
||||
let InputItem::Message(last) = &merged[4] else { panic!("Expected Message") };
|
||||
let InputItem::Message(last) = &merged[4] else {
|
||||
panic!("Expected Message")
|
||||
};
|
||||
assert!(matches!(last.role, MessageRole::User));
|
||||
match &last.content {
|
||||
MessageContent::Items(items) => match &items[0] {
|
||||
|
|
@ -590,7 +613,9 @@ mod tests {
|
|||
assert_eq!(merged.len(), 5);
|
||||
|
||||
// Verify the entire conversation flow is preserved
|
||||
let InputItem::Message(first) = &merged[0] else { panic!("Expected Message") };
|
||||
let InputItem::Message(first) = &merged[0] else {
|
||||
panic!("Expected Message")
|
||||
};
|
||||
match &first.content {
|
||||
MessageContent::Items(items) => match &items[0] {
|
||||
InputContent::InputText { text } => assert!(text.contains("What's the weather")),
|
||||
|
|
@ -599,7 +624,9 @@ mod tests {
|
|||
_ => panic!("Expected MessageContent::Items"),
|
||||
}
|
||||
|
||||
let InputItem::Message(last) = &merged[4] else { panic!("Expected Message") };
|
||||
let InputItem::Message(last) = &merged[4] else {
|
||||
panic!("Expected Message")
|
||||
};
|
||||
match &last.content {
|
||||
MessageContent::Items(items) => match &items[0] {
|
||||
InputContent::InputText { text } => assert!(text.contains("umbrella")),
|
||||
|
|
|
|||
|
|
@ -1,14 +1,16 @@
|
|||
use async_trait::async_trait;
|
||||
use hermesllm::apis::openai_responses::{InputItem, InputMessage, InputContent, MessageContent, MessageRole, InputParam};
|
||||
use hermesllm::apis::openai_responses::{
|
||||
InputContent, InputItem, InputMessage, InputParam, MessageContent, MessageRole,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug};
|
||||
use tracing::debug;
|
||||
|
||||
pub mod memory;
|
||||
pub mod response_state_processor;
|
||||
pub mod postgresql;
|
||||
pub mod response_state_processor;
|
||||
|
||||
/// Represents the conversational state for a v1/responses request
|
||||
/// Contains the complete input/output history that can be restored
|
||||
|
|
@ -47,7 +49,9 @@ pub enum StateStorageError {
|
|||
impl fmt::Display for StateStorageError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
StateStorageError::NotFound(id) => write!(f, "Conversation state not found for response_id: {}", id),
|
||||
StateStorageError::NotFound(id) => {
|
||||
write!(f, "Conversation state not found for response_id: {}", id)
|
||||
}
|
||||
StateStorageError::StorageError(msg) => write!(f, "Storage error: {}", msg),
|
||||
StateStorageError::SerializationError(msg) => write!(f, "Serialization error: {}", msg),
|
||||
}
|
||||
|
|
@ -96,8 +100,6 @@ pub trait StateStorage: Send + Sync {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
/// Storage backend type enum
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum StorageBackend {
|
||||
|
|
@ -106,7 +108,7 @@ pub enum StorageBackend {
|
|||
}
|
||||
|
||||
impl StorageBackend {
|
||||
pub fn from_str(s: &str) -> Option<Self> {
|
||||
pub fn parse_backend(s: &str) -> Option<Self> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"memory" => Some(StorageBackend::Memory),
|
||||
"supabase" => Some(StorageBackend::Supabase),
|
||||
|
|
@ -139,7 +141,6 @@ pub async fn retrieve_and_combine_input(
|
|||
previous_response_id: &str,
|
||||
current_input: Vec<InputItem>,
|
||||
) -> Result<Vec<InputItem>, StateStorageError> {
|
||||
|
||||
// First get the previous state
|
||||
let prev_state = storage.get(previous_response_id).await?;
|
||||
let combined_input = storage.merge(&prev_state, current_input);
|
||||
|
|
|
|||
|
|
@ -149,13 +149,12 @@ impl StateStorage for PostgreSQLConversationStorage {
|
|||
let provider: String = row.get("provider");
|
||||
|
||||
// Deserialize input_items from JSONB
|
||||
let input_items =
|
||||
serde_json::from_value(input_items_json).map_err(|e| {
|
||||
StateStorageError::StorageError(format!(
|
||||
"Failed to deserialize input_items: {}",
|
||||
e
|
||||
))
|
||||
})?;
|
||||
let input_items = serde_json::from_value(input_items_json).map_err(|e| {
|
||||
StateStorageError::StorageError(format!(
|
||||
"Failed to deserialize input_items: {}",
|
||||
e
|
||||
))
|
||||
})?;
|
||||
|
||||
Ok(OpenAIConversationState {
|
||||
response_id,
|
||||
|
|
@ -230,7 +229,9 @@ Run that SQL file against your database before using this storage backend.
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use hermesllm::apis::openai_responses::{InputContent, InputItem, InputMessage, MessageContent, MessageRole};
|
||||
use hermesllm::apis::openai_responses::{
|
||||
InputContent, InputItem, InputMessage, MessageContent, MessageRole,
|
||||
};
|
||||
|
||||
fn create_test_state(response_id: &str) -> OpenAIConversationState {
|
||||
OpenAIConversationState {
|
||||
|
|
@ -320,7 +321,10 @@ mod tests {
|
|||
|
||||
let result = storage.get("nonexistent_id").await;
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(result.unwrap_err(), StateStorageError::NotFound(_)));
|
||||
assert!(matches!(
|
||||
result.unwrap_err(),
|
||||
StateStorageError::NotFound(_)
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -372,7 +376,10 @@ mod tests {
|
|||
|
||||
let result = storage.delete("nonexistent_id").await;
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(result.unwrap_err(), StateStorageError::NotFound(_)));
|
||||
assert!(matches!(
|
||||
result.unwrap_err(),
|
||||
StateStorageError::NotFound(_)
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -423,9 +430,13 @@ mod tests {
|
|||
|
||||
println!("✅ Data written to Supabase!");
|
||||
println!("Check your Supabase dashboard:");
|
||||
println!(" SELECT * FROM conversation_states WHERE response_id = 'manual_test_verification';");
|
||||
println!(
|
||||
" SELECT * FROM conversation_states WHERE response_id = 'manual_test_verification';"
|
||||
);
|
||||
println!("\nTo cleanup, run:");
|
||||
println!(" DELETE FROM conversation_states WHERE response_id = 'manual_test_verification';");
|
||||
println!(
|
||||
" DELETE FROM conversation_states WHERE response_id = 'manual_test_verification';"
|
||||
);
|
||||
|
||||
// DON'T cleanup - leave it for manual verification
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,13 +1,11 @@
|
|||
use bytes::Bytes;
|
||||
use flate2::read::GzDecoder;
|
||||
use hermesllm::apis::openai_responses::{
|
||||
InputItem, OutputItem, ResponsesAPIStreamEvent,
|
||||
};
|
||||
use hermesllm::apis::openai_responses::{InputItem, OutputItem, ResponsesAPIStreamEvent};
|
||||
use hermesllm::apis::streaming_shapes::sse::SseStreamIter;
|
||||
use hermesllm::transforms::response::output_to_input::outputs_to_inputs;
|
||||
use std::io::Read;
|
||||
use std::sync::Arc;
|
||||
use tracing::{info, debug, warn};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::handlers::utils::StreamProcessor;
|
||||
use crate::state::{OpenAIConversationState, StateStorage};
|
||||
|
|
@ -53,6 +51,7 @@ pub struct ResponsesStateProcessor<P: StreamProcessor> {
|
|||
}
|
||||
|
||||
impl<P: StreamProcessor> ResponsesStateProcessor<P> {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
inner: P,
|
||||
storage: Arc<dyn StateStorage>,
|
||||
|
|
@ -139,20 +138,19 @@ impl<P: StreamProcessor> ResponsesStateProcessor<P> {
|
|||
for event in sse_iter {
|
||||
// Only process data lines (skip event-only lines)
|
||||
if let Some(data_str) = &event.data {
|
||||
// Try to parse as ResponsesAPIStreamEvent
|
||||
if let Ok(stream_event) = serde_json::from_str::<ResponsesAPIStreamEvent>(data_str) {
|
||||
// Check if this is a ResponseCompleted event
|
||||
if let ResponsesAPIStreamEvent::ResponseCompleted { response, .. } = stream_event {
|
||||
info!(
|
||||
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Captured streaming response.completed: response_id={}, output_items={}",
|
||||
self.request_id,
|
||||
response.id,
|
||||
response.output.len()
|
||||
);
|
||||
self.response_id = Some(response.id.clone());
|
||||
self.output_items = Some(response.output.clone());
|
||||
return; // Found what we need, exit early
|
||||
}
|
||||
// Try to parse as ResponsesAPIStreamEvent and check if it's a ResponseCompleted event
|
||||
if let Ok(ResponsesAPIStreamEvent::ResponseCompleted { response, .. }) =
|
||||
serde_json::from_str::<ResponsesAPIStreamEvent>(data_str)
|
||||
{
|
||||
info!(
|
||||
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Captured streaming response.completed: response_id={}, output_items={}",
|
||||
self.request_id,
|
||||
response.id,
|
||||
response.output.len()
|
||||
);
|
||||
self.response_id = Some(response.id.clone());
|
||||
self.output_items = Some(response.output.clone());
|
||||
return; // Found what we need, exit early
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -172,7 +170,9 @@ impl<P: StreamProcessor> ResponsesStateProcessor<P> {
|
|||
let decompressed = self.decompress_buffer();
|
||||
|
||||
// Parse complete JSON response
|
||||
match serde_json::from_slice::<hermesllm::apis::openai_responses::ResponsesAPIResponse>(&decompressed) {
|
||||
match serde_json::from_slice::<hermesllm::apis::openai_responses::ResponsesAPIResponse>(
|
||||
&decompressed,
|
||||
) {
|
||||
Ok(response) => {
|
||||
info!(
|
||||
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Captured non-streaming response: response_id={}, output_items={}",
|
||||
|
|
|
|||
|
|
@ -2,11 +2,9 @@
|
|||
///
|
||||
/// This module defines standard attribute keys following OTEL semantic conventions.
|
||||
/// See: https://opentelemetry.io/docs/specs/semconv/
|
||||
|
||||
// =============================================================================
|
||||
// Span Attributes - HTTP
|
||||
// =============================================================================
|
||||
|
||||
/// Semantic conventions for HTTP-related span attributes
|
||||
pub mod http {
|
||||
/// HTTP request method
|
||||
|
|
|
|||
|
|
@ -1,3 +1,3 @@
|
|||
mod constants;
|
||||
|
||||
pub use constants::{OperationNameBuilder, operation_component, http, llm, error, routing};
|
||||
pub use constants::{error, http, llm, operation_component, routing, OperationNameBuilder};
|
||||
|
|
|
|||
|
|
@ -295,11 +295,14 @@ impl serde::Serialize for OrchestrationPreference {
|
|||
let mut state = serializer.serialize_struct("OrchestrationPreference", 3)?;
|
||||
state.serialize_field("name", &self.name)?;
|
||||
state.serialize_field("description", &self.description)?;
|
||||
state.serialize_field("parameters", &serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}))?;
|
||||
state.serialize_field(
|
||||
"parameters",
|
||||
&serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}),
|
||||
)?;
|
||||
state.end()
|
||||
}
|
||||
}
|
||||
|
|
@ -489,7 +492,10 @@ mod test {
|
|||
assert_eq!(config.version, "v0.3.0");
|
||||
|
||||
if let Some(prompt_targets) = &config.prompt_targets {
|
||||
assert!(!prompt_targets.is_empty(), "prompt_targets should not be empty if present");
|
||||
assert!(
|
||||
!prompt_targets.is_empty(),
|
||||
"prompt_targets should not be empty if present"
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(tracing) = config.tracing.as_ref() {
|
||||
|
|
@ -510,19 +516,48 @@ mod test {
|
|||
.expect("reference config file not found");
|
||||
let config: super::Configuration = serde_yaml::from_str(&ref_config).unwrap();
|
||||
if let Some(prompt_targets) = &config.prompt_targets {
|
||||
if let Some(prompt_target) = prompt_targets.iter().find(|p| p.name == "reboot_network_device") {
|
||||
if let Some(prompt_target) = prompt_targets
|
||||
.iter()
|
||||
.find(|p| p.name == "reboot_network_device")
|
||||
{
|
||||
let chat_completion_tool: super::ChatCompletionTool = prompt_target.into();
|
||||
assert_eq!(chat_completion_tool.tool_type, ToolType::Function);
|
||||
assert_eq!(chat_completion_tool.function.name, "reboot_network_device");
|
||||
assert_eq!(chat_completion_tool.function.description, "Reboot a specific network device");
|
||||
assert_eq!(
|
||||
chat_completion_tool.function.description,
|
||||
"Reboot a specific network device"
|
||||
);
|
||||
assert_eq!(chat_completion_tool.function.parameters.properties.len(), 2);
|
||||
assert!(chat_completion_tool.function.parameters.properties.contains_key("device_id"));
|
||||
let device_id_param = chat_completion_tool.function.parameters.properties.get("device_id").unwrap();
|
||||
assert_eq!(device_id_param.parameter_type, crate::api::open_ai::ParameterType::String);
|
||||
assert_eq!(device_id_param.description, "Identifier of the network device to reboot.".to_string());
|
||||
assert!(chat_completion_tool
|
||||
.function
|
||||
.parameters
|
||||
.properties
|
||||
.contains_key("device_id"));
|
||||
let device_id_param = chat_completion_tool
|
||||
.function
|
||||
.parameters
|
||||
.properties
|
||||
.get("device_id")
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
device_id_param.parameter_type,
|
||||
crate::api::open_ai::ParameterType::String
|
||||
);
|
||||
assert_eq!(
|
||||
device_id_param.description,
|
||||
"Identifier of the network device to reboot.".to_string()
|
||||
);
|
||||
assert_eq!(device_id_param.required, Some(true));
|
||||
let confirmation_param = chat_completion_tool.function.parameters.properties.get("confirmation").unwrap();
|
||||
assert_eq!(confirmation_param.parameter_type, crate::api::open_ai::ParameterType::Bool);
|
||||
let confirmation_param = chat_completion_tool
|
||||
.function
|
||||
.parameters
|
||||
.properties
|
||||
.get("confirmation")
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
confirmation_param.parameter_type,
|
||||
crate::api::open_ai::ParameterType::Bool
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -32,6 +32,6 @@ pub const OTEL_COLLECTOR_HTTP: &str = "opentelemetry_collector_http";
|
|||
pub const OTEL_POST_PATH: &str = "/v1/traces";
|
||||
pub const LLM_ROUTE_HEADER: &str = "x-arch-llm-route";
|
||||
pub const ENVOY_RETRY_HEADER: &str = "x-envoy-max-retries";
|
||||
pub const BRIGHT_STAFF_SERVICE_NAME : &str = "brightstaff";
|
||||
pub const BRIGHT_STAFF_SERVICE_NAME: &str = "brightstaff";
|
||||
pub const PLANO_ORCHESTRATOR_MODEL_NAME: &str = "Plano-Orchestrator";
|
||||
pub const ARCH_FC_CLUSTER: &str = "arch";
|
||||
|
|
|
|||
|
|
@ -10,6 +10,6 @@ pub mod ratelimit;
|
|||
pub mod routing;
|
||||
pub mod stats;
|
||||
pub mod tokenizer;
|
||||
pub mod tracing;
|
||||
pub mod traces;
|
||||
pub mod tracing;
|
||||
pub mod utils;
|
||||
|
|
|
|||
|
|
@ -41,7 +41,8 @@ pub fn get_llm_provider(
|
|||
llm_providers
|
||||
.iter()
|
||||
.filter(|(_, provider)| {
|
||||
provider.model
|
||||
provider
|
||||
.model
|
||||
.as_ref()
|
||||
.map(|m| !m.starts_with("Arch"))
|
||||
.unwrap_or(true)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use super::shapes::Span;
|
||||
use super::resource_span_builder::ResourceSpanBuilder;
|
||||
use super::shapes::Span;
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
|
|
@ -160,7 +160,11 @@ impl TraceCollector {
|
|||
}
|
||||
|
||||
let total_spans: usize = service_batches.iter().map(|(_, spans)| spans.len()).sum();
|
||||
debug!("Flushing {} spans across {} services to OTEL collector", total_spans, service_batches.len());
|
||||
debug!(
|
||||
"Flushing {} spans across {} services to OTEL collector",
|
||||
total_spans,
|
||||
service_batches.len()
|
||||
);
|
||||
|
||||
// Build canonical OTEL payload structure - one ResourceSpan per service
|
||||
let resource_spans = self.build_resource_spans(service_batches);
|
||||
|
|
@ -178,7 +182,10 @@ impl TraceCollector {
|
|||
}
|
||||
|
||||
/// Build OTEL-compliant resource spans from collected spans, one ResourceSpan per service
|
||||
fn build_resource_spans(&self, service_batches: Vec<(String, Vec<Span>)>) -> Vec<super::shapes::ResourceSpan> {
|
||||
fn build_resource_spans(
|
||||
&self,
|
||||
service_batches: Vec<(String, Vec<Span>)>,
|
||||
) -> Vec<super::shapes::ResourceSpan> {
|
||||
service_batches
|
||||
.into_iter()
|
||||
.map(|(service_name, spans)| {
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
/// OpenTelemetry semantic convention constants for tracing
|
||||
///
|
||||
/// These constants ensure consistency across the codebase and prevent typos
|
||||
|
||||
/// Resource attribute keys following OTEL semantic conventions
|
||||
pub mod resource {
|
||||
/// Logical name of the service
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
// Original tracing types (OTEL structures)
|
||||
mod shapes;
|
||||
// New tracing utilities
|
||||
mod span_builder;
|
||||
mod resource_span_builder;
|
||||
mod constants;
|
||||
mod resource_span_builder;
|
||||
mod span_builder;
|
||||
|
||||
#[cfg(feature = "trace-collection")]
|
||||
mod collector;
|
||||
|
|
@ -13,14 +13,14 @@ mod tests;
|
|||
|
||||
// Re-export original types
|
||||
pub use shapes::{
|
||||
Span, Event, Traceparent, TraceparentNewError,
|
||||
ResourceSpan, Resource, ScopeSpan, Scope, Attribute, AttributeValue,
|
||||
Attribute, AttributeValue, Event, Resource, ResourceSpan, Scope, ScopeSpan, Span, Traceparent,
|
||||
TraceparentNewError,
|
||||
};
|
||||
|
||||
// Re-export new utilities
|
||||
pub use span_builder::{SpanBuilder, SpanKind, generate_random_span_id};
|
||||
pub use resource_span_builder::ResourceSpanBuilder;
|
||||
pub use constants::*;
|
||||
pub use resource_span_builder::ResourceSpanBuilder;
|
||||
pub use span_builder::{generate_random_span_id, SpanBuilder, SpanKind};
|
||||
|
||||
#[cfg(feature = "trace-collection")]
|
||||
pub use collector::{TraceCollector, parse_traceparent};
|
||||
pub use collector::{parse_traceparent, TraceCollector};
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use super::shapes::{ResourceSpan, Resource, ScopeSpan, Scope, Span, Attribute, AttributeValue};
|
||||
use super::constants::{resource, scope};
|
||||
use super::shapes::{Attribute, AttributeValue, Resource, ResourceSpan, Scope, ScopeSpan, Span};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Builder for creating OTEL ResourceSpan structures
|
||||
|
|
@ -26,7 +26,11 @@ impl ResourceSpanBuilder {
|
|||
}
|
||||
|
||||
/// Add a resource attribute (e.g., deployment.environment, host.name)
|
||||
pub fn with_resource_attribute(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
|
||||
pub fn with_resource_attribute(
|
||||
mut self,
|
||||
key: impl Into<String>,
|
||||
value: impl Into<String>,
|
||||
) -> Self {
|
||||
self.resource_attributes.insert(key.into(), value.into());
|
||||
self
|
||||
}
|
||||
|
|
@ -58,14 +62,12 @@ impl ResourceSpanBuilder {
|
|||
/// Build the ResourceSpan
|
||||
pub fn build(self) -> ResourceSpan {
|
||||
// Build resource attributes
|
||||
let mut attributes = vec![
|
||||
Attribute {
|
||||
key: resource::SERVICE_NAME.to_string(),
|
||||
value: AttributeValue {
|
||||
string_value: Some(self.service_name),
|
||||
},
|
||||
}
|
||||
];
|
||||
let mut attributes = vec![Attribute {
|
||||
key: resource::SERVICE_NAME.to_string(),
|
||||
value: AttributeValue {
|
||||
string_value: Some(self.service_name),
|
||||
},
|
||||
}];
|
||||
|
||||
// Add custom resource attributes
|
||||
for (key, value) in self.resource_attributes {
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
use super::shapes::{Span, Attribute, AttributeValue};
|
||||
use super::shapes::{Attribute, AttributeValue, Span};
|
||||
use std::collections::HashMap;
|
||||
use std::time::SystemTime;
|
||||
|
||||
|
|
@ -116,10 +116,11 @@ impl SpanBuilder {
|
|||
let end_nanos = system_time_to_nanos(end_time);
|
||||
|
||||
// Generate trace_id if not provided
|
||||
let trace_id = self.trace_id.unwrap_or_else(|| generate_random_trace_id());
|
||||
let trace_id = self.trace_id.unwrap_or_else(generate_random_trace_id);
|
||||
|
||||
// Create attributes in OTEL format
|
||||
let attributes: Vec<Attribute> = self.attributes
|
||||
let attributes: Vec<Attribute> = self
|
||||
.attributes
|
||||
.into_iter()
|
||||
.map(|(key, value)| Attribute {
|
||||
key,
|
||||
|
|
@ -132,7 +133,7 @@ impl SpanBuilder {
|
|||
// Build span directly without going through Span::new()
|
||||
Span {
|
||||
trace_id,
|
||||
span_id: self.span_id.unwrap_or_else(|| generate_random_span_id()),
|
||||
span_id: self.span_id.unwrap_or_else(generate_random_span_id),
|
||||
parent_span_id: self.parent_span_id,
|
||||
name: self.name,
|
||||
start_time_unix_nano: format!("{}", start_nanos),
|
||||
|
|
|
|||
|
|
@ -21,10 +21,7 @@ use tokio::sync::RwLock;
|
|||
type SharedTraces = Arc<RwLock<Vec<Value>>>;
|
||||
|
||||
/// POST /v1/traces - capture incoming OTLP payload
|
||||
async fn post_traces(
|
||||
State(traces): State<SharedTraces>,
|
||||
Json(payload): Json<Value>,
|
||||
) -> StatusCode {
|
||||
async fn post_traces(State(traces): State<SharedTraces>, Json(payload): Json<Value>) -> StatusCode {
|
||||
traces.write().await.push(payload);
|
||||
StatusCode::OK
|
||||
}
|
||||
|
|
@ -67,9 +64,7 @@ impl MockOtelCollector {
|
|||
let address = format!("http://127.0.0.1:{}", addr.port());
|
||||
|
||||
let server_handle = tokio::spawn(async move {
|
||||
axum::serve(listener, app)
|
||||
.await
|
||||
.expect("Server failed");
|
||||
axum::serve(listener, app).await.expect("Server failed");
|
||||
});
|
||||
|
||||
// Give server a moment to start
|
||||
|
|
|
|||
|
|
@ -36,9 +36,12 @@ fn extract_spans(payloads: &[Value]) -> Vec<&Value> {
|
|||
for payload in payloads {
|
||||
if let Some(resource_spans) = payload.get("resourceSpans").and_then(|v| v.as_array()) {
|
||||
for resource_span in resource_spans {
|
||||
if let Some(scope_spans) = resource_span.get("scopeSpans").and_then(|v| v.as_array()) {
|
||||
if let Some(scope_spans) =
|
||||
resource_span.get("scopeSpans").and_then(|v| v.as_array())
|
||||
{
|
||||
for scope_span in scope_spans {
|
||||
if let Some(span_list) = scope_span.get("spans").and_then(|v| v.as_array()) {
|
||||
if let Some(span_list) = scope_span.get("spans").and_then(|v| v.as_array())
|
||||
{
|
||||
spans.extend(span_list.iter());
|
||||
}
|
||||
}
|
||||
|
|
@ -54,9 +57,9 @@ fn get_string_attr<'a>(span: &'a Value, key: &str) -> Option<&'a str> {
|
|||
span.get("attributes")
|
||||
.and_then(|attrs| attrs.as_array())
|
||||
.and_then(|attrs| {
|
||||
attrs.iter().find(|attr| {
|
||||
attr.get("key").and_then(|k| k.as_str()) == Some(key)
|
||||
})
|
||||
attrs
|
||||
.iter()
|
||||
.find(|attr| attr.get("key").and_then(|k| k.as_str()) == Some(key))
|
||||
})
|
||||
.and_then(|attr| attr.get("value"))
|
||||
.and_then(|v| v.get("stringValue"))
|
||||
|
|
@ -70,7 +73,10 @@ async fn test_llm_span_contains_basic_attributes() {
|
|||
let mock_collector = MockOtelCollector::start().await;
|
||||
|
||||
// Create TraceCollector pointing to mock with 500ms flush intervalc
|
||||
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address()));
|
||||
std::env::set_var(
|
||||
"OTEL_COLLECTOR_URL",
|
||||
format!("{}/v1/traces", mock_collector.address()),
|
||||
);
|
||||
std::env::set_var("OTEL_TRACING_ENABLED", "true");
|
||||
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
|
||||
let trace_collector = Arc::new(TraceCollector::new(Some(true)));
|
||||
|
|
@ -102,7 +108,10 @@ async fn test_llm_span_contains_basic_attributes() {
|
|||
let span = spans[0];
|
||||
// Validate HTTP attributes
|
||||
assert_eq!(get_string_attr(span, "http.method"), Some("POST"));
|
||||
assert_eq!(get_string_attr(span, "http.target"), Some("/v1/chat/completions"));
|
||||
assert_eq!(
|
||||
get_string_attr(span, "http.target"),
|
||||
Some("/v1/chat/completions")
|
||||
);
|
||||
|
||||
// Validate LLM attributes
|
||||
assert_eq!(get_string_attr(span, "llm.model"), Some("gpt-4o"));
|
||||
|
|
@ -115,7 +124,10 @@ async fn test_llm_span_contains_basic_attributes() {
|
|||
#[serial]
|
||||
async fn test_llm_span_contains_tool_information() {
|
||||
let mock_collector = MockOtelCollector::start().await;
|
||||
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address()));
|
||||
std::env::set_var(
|
||||
"OTEL_COLLECTOR_URL",
|
||||
format!("{}/v1/traces", mock_collector.address()),
|
||||
);
|
||||
std::env::set_var("OTEL_TRACING_ENABLED", "true");
|
||||
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
|
||||
let trace_collector = Arc::new(TraceCollector::new(Some(true)));
|
||||
|
|
@ -144,19 +156,26 @@ async fn test_llm_span_contains_tool_information() {
|
|||
assert!(tools.unwrap().contains("get_weather(...)"));
|
||||
assert!(tools.unwrap().contains("search_web(...)"));
|
||||
assert!(tools.unwrap().contains("calculate(...)"));
|
||||
assert!(tools.unwrap().contains('\n'), "Tools should be newline-separated");
|
||||
assert!(
|
||||
tools.unwrap().contains('\n'),
|
||||
"Tools should be newline-separated"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_llm_span_contains_user_message_preview() {
|
||||
let mock_collector = MockOtelCollector::start().await;
|
||||
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address()));
|
||||
std::env::set_var(
|
||||
"OTEL_COLLECTOR_URL",
|
||||
format!("{}/v1/traces", mock_collector.address()),
|
||||
);
|
||||
std::env::set_var("OTEL_TRACING_ENABLED", "true");
|
||||
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
|
||||
let trace_collector = Arc::new(TraceCollector::new(Some(true)));
|
||||
|
||||
let long_message = "This is a very long user message that should be truncated to 50 characters in the span";
|
||||
let long_message =
|
||||
"This is a very long user message that should be truncated to 50 characters in the span";
|
||||
let preview = if long_message.len() > 50 {
|
||||
format!("{}...", &long_message[..50])
|
||||
} else {
|
||||
|
|
@ -187,7 +206,10 @@ async fn test_llm_span_contains_user_message_preview() {
|
|||
#[serial]
|
||||
async fn test_llm_span_contains_time_to_first_token() {
|
||||
let mock_collector = MockOtelCollector::start().await;
|
||||
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address()));
|
||||
std::env::set_var(
|
||||
"OTEL_COLLECTOR_URL",
|
||||
format!("{}/v1/traces", mock_collector.address()),
|
||||
);
|
||||
std::env::set_var("OTEL_TRACING_ENABLED", "true");
|
||||
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
|
||||
let trace_collector = Arc::new(TraceCollector::new(Some(true)));
|
||||
|
|
@ -217,7 +239,10 @@ async fn test_llm_span_contains_time_to_first_token() {
|
|||
#[serial]
|
||||
async fn test_llm_span_contains_upstream_path() {
|
||||
let mock_collector = MockOtelCollector::start().await;
|
||||
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address()));
|
||||
std::env::set_var(
|
||||
"OTEL_COLLECTOR_URL",
|
||||
format!("{}/v1/traces", mock_collector.address()),
|
||||
);
|
||||
std::env::set_var("OTEL_TRACING_ENABLED", "true");
|
||||
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
|
||||
let trace_collector = Arc::new(TraceCollector::new(Some(true)));
|
||||
|
|
@ -241,7 +266,10 @@ async fn test_llm_span_contains_upstream_path() {
|
|||
// Operation name should show the transformation
|
||||
let name = span.get("name").and_then(|v| v.as_str());
|
||||
assert!(name.is_some());
|
||||
assert!(name.unwrap().contains(">>"), "Operation name should show path transformation");
|
||||
assert!(
|
||||
name.unwrap().contains(">>"),
|
||||
"Operation name should show path transformation"
|
||||
);
|
||||
|
||||
// Check upstream target attribute
|
||||
let upstream = get_string_attr(span, "http.upstream_target");
|
||||
|
|
@ -252,7 +280,10 @@ async fn test_llm_span_contains_upstream_path() {
|
|||
#[serial]
|
||||
async fn test_llm_span_multiple_services() {
|
||||
let mock_collector = MockOtelCollector::start().await;
|
||||
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address()));
|
||||
std::env::set_var(
|
||||
"OTEL_COLLECTOR_URL",
|
||||
format!("{}/v1/traces", mock_collector.address()),
|
||||
);
|
||||
std::env::set_var("OTEL_TRACING_ENABLED", "true");
|
||||
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
|
||||
let trace_collector = Arc::new(TraceCollector::new(Some(true)));
|
||||
|
|
@ -285,7 +316,10 @@ async fn test_tracing_disabled_produces_no_spans() {
|
|||
let mock_collector = MockOtelCollector::start().await;
|
||||
|
||||
// Create TraceCollector with tracing DISABLED
|
||||
std::env::set_var("OTEL_COLLECTOR_URL", format!("{}/v1/traces", mock_collector.address()));
|
||||
std::env::set_var(
|
||||
"OTEL_COLLECTOR_URL",
|
||||
format!("{}/v1/traces", mock_collector.address()),
|
||||
);
|
||||
std::env::set_var("OTEL_TRACING_ENABLED", "false");
|
||||
std::env::set_var("TRACE_FLUSH_INTERVAL_MS", "500");
|
||||
let trace_collector = Arc::new(TraceCollector::new(Some(false)));
|
||||
|
|
@ -300,5 +334,9 @@ async fn test_tracing_disabled_produces_no_spans() {
|
|||
|
||||
let payloads = mock_collector.get_traces().await;
|
||||
let all_spans = extract_spans(&payloads);
|
||||
assert_eq!(all_spans.len(), 0, "No spans should be captured when tracing is disabled");
|
||||
assert_eq!(
|
||||
all_spans.len(),
|
||||
0,
|
||||
"No spans should be captured when tracing is disabled"
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -161,13 +161,12 @@ impl TraceData {
|
|||
}
|
||||
|
||||
pub fn new_with_service_name(service_name: String) -> Self {
|
||||
let mut resource_attributes = Vec::new();
|
||||
resource_attributes.push(Attribute {
|
||||
let resource_attributes = vec![Attribute {
|
||||
key: "service.name".to_string(),
|
||||
value: AttributeValue {
|
||||
string_value: Some(service_name),
|
||||
},
|
||||
});
|
||||
}];
|
||||
|
||||
let resource = Resource {
|
||||
attributes: resource_attributes,
|
||||
|
|
@ -194,7 +193,9 @@ impl TraceData {
|
|||
|
||||
pub fn add_span(&mut self, span: Span) {
|
||||
if self.resource_spans.is_empty() {
|
||||
let resource = Resource { attributes: Vec::new() };
|
||||
let resource = Resource {
|
||||
attributes: Vec::new(),
|
||||
};
|
||||
let scope_span = ScopeSpan {
|
||||
scope: Scope {
|
||||
name: "default".to_string(),
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ impl ApiDefinition for AmazonBedrockApi {
|
|||
|
||||
/// Amazon Bedrock Converse request
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
|
||||
pub struct ConverseRequest {
|
||||
/// The model ID or ARN to invoke
|
||||
pub model_id: String,
|
||||
|
|
@ -91,7 +91,7 @@ pub struct ConverseRequest {
|
|||
pub additional_model_response_field_paths: Option<Vec<String>>,
|
||||
/// Performance configuration
|
||||
#[serde(rename = "performanceConfig")]
|
||||
pub performance_config: Option<PerformanceConfiguration>,
|
||||
pub performance_config: Option<InferenceConfiguration>,
|
||||
/// Prompt variables for Prompt management
|
||||
#[serde(rename = "promptVariables")]
|
||||
pub prompt_variables: Option<HashMap<String, PromptVariableValues>>,
|
||||
|
|
@ -105,26 +105,6 @@ pub struct ConverseRequest {
|
|||
pub stream: bool,
|
||||
}
|
||||
|
||||
impl Default for ConverseRequest {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
model_id: String::new(),
|
||||
messages: None,
|
||||
system: None,
|
||||
inference_config: None,
|
||||
tool_config: None,
|
||||
guardrail_config: None,
|
||||
additional_model_request_fields: None,
|
||||
additional_model_response_field_paths: None,
|
||||
performance_config: None,
|
||||
prompt_variables: None,
|
||||
request_metadata: None,
|
||||
metadata: None,
|
||||
stream: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Amazon Bedrock ConverseStream request (same structure as Converse)
|
||||
pub type ConverseStreamRequest = ConverseRequest;
|
||||
|
||||
|
|
@ -204,8 +184,8 @@ impl ProviderRequest for ConverseRequest {
|
|||
self.tool_config.as_ref()?.tools.as_ref().map(|tools| {
|
||||
tools
|
||||
.iter()
|
||||
.filter_map(|tool| match tool {
|
||||
Tool::ToolSpec { tool_spec } => Some(tool_spec.name.clone()),
|
||||
.map(|tool| match tool {
|
||||
Tool::ToolSpec { tool_spec } => tool_spec.name.clone(),
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
|
|
@ -242,17 +222,14 @@ impl ProviderRequest for ConverseRequest {
|
|||
// Add system messages if present
|
||||
if let Some(system) = &self.system {
|
||||
for sys_block in system {
|
||||
match sys_block {
|
||||
SystemContentBlock::Text { text } => {
|
||||
openai_messages.push(Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text(text.clone()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
}
|
||||
_ => {} // Skip other system content types
|
||||
if let SystemContentBlock::Text { text } = sys_block {
|
||||
openai_messages.push(Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text(text.clone()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -266,7 +243,9 @@ impl ProviderRequest for ConverseRequest {
|
|||
};
|
||||
|
||||
// Extract text from content blocks
|
||||
let content = msg.content.iter()
|
||||
let content = msg
|
||||
.content
|
||||
.iter()
|
||||
.filter_map(|block| {
|
||||
if let ContentBlock::Text { text } = block {
|
||||
Some(text.clone())
|
||||
|
|
@ -311,16 +290,14 @@ impl ProviderRequest for ConverseRequest {
|
|||
_ => continue,
|
||||
};
|
||||
|
||||
let content = if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
|
||||
vec![ContentBlock::Text { text: text.clone() }]
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
let content =
|
||||
if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
|
||||
vec![ContentBlock::Text { text: text.clone() }]
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
bedrock_messages.push(crate::apis::amazon_bedrock::Message {
|
||||
role,
|
||||
content,
|
||||
});
|
||||
bedrock_messages.push(crate::apis::amazon_bedrock::Message { role, content });
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
|
@ -369,7 +346,7 @@ pub enum ConverseStreamEvent {
|
|||
ContentBlockDelta(ContentBlockDeltaEvent),
|
||||
ContentBlockStop(ContentBlockStopEvent),
|
||||
MessageStop(MessageStopEvent),
|
||||
Metadata(ConverseStreamMetadataEvent),
|
||||
Metadata(Box<ConverseStreamMetadataEvent>),
|
||||
// Error events
|
||||
InternalServerException(BedrockException),
|
||||
ModelStreamErrorException(BedrockException),
|
||||
|
|
@ -1063,7 +1040,7 @@ impl TryFrom<&aws_smithy_eventstream::frame::DecodedFrame> for ConverseStreamEve
|
|||
"metadata" => {
|
||||
let event: ConverseStreamMetadataEvent =
|
||||
serde_json::from_slice(payload).map_err(BedrockError::Serialization)?;
|
||||
Ok(ConverseStreamEvent::Metadata(event))
|
||||
Ok(ConverseStreamEvent::Metadata(Box::new(event)))
|
||||
}
|
||||
unknown => Err(BedrockError::Validation {
|
||||
message: format!("Unknown event type: {}", unknown),
|
||||
|
|
@ -1106,10 +1083,10 @@ impl TryFrom<&aws_smithy_eventstream::frame::DecodedFrame> for ConverseStreamEve
|
|||
}
|
||||
}
|
||||
|
||||
impl Into<String> for ConverseStreamEvent {
|
||||
fn into(self) -> String {
|
||||
let transformed_json = serde_json::to_string(&self).unwrap_or_default();
|
||||
let event_type = match &self {
|
||||
impl From<ConverseStreamEvent> for String {
|
||||
fn from(val: ConverseStreamEvent) -> String {
|
||||
let transformed_json = serde_json::to_string(&val).unwrap_or_default();
|
||||
let event_type = match &val {
|
||||
ConverseStreamEvent::MessageStart { .. } => "message_start",
|
||||
ConverseStreamEvent::ContentBlockStart { .. } => "content_block_start",
|
||||
ConverseStreamEvent::ContentBlockDelta { .. } => "content_block_delta",
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
|
|
@ -286,7 +286,6 @@ pub struct ImageUrl {
|
|||
}
|
||||
|
||||
/// A single message in a chat conversation
|
||||
|
||||
/// A tool call made by the assistant
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
|
||||
pub struct ToolCall {
|
||||
|
|
@ -388,7 +387,7 @@ pub enum StaticContentType {
|
|||
|
||||
/// Chat completions API response
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
|
||||
pub struct ChatCompletionsResponse {
|
||||
pub id: String,
|
||||
pub object: Option<String>,
|
||||
|
|
@ -402,22 +401,6 @@ pub struct ChatCompletionsResponse {
|
|||
pub metadata: Option<HashMap<String, Value>>,
|
||||
}
|
||||
|
||||
impl Default for ChatCompletionsResponse {
|
||||
fn default() -> Self {
|
||||
ChatCompletionsResponse {
|
||||
id: String::new(),
|
||||
object: None,
|
||||
created: 0,
|
||||
model: String::new(),
|
||||
choices: vec![],
|
||||
usage: Usage::default(),
|
||||
system_fingerprint: None,
|
||||
service_tier: None,
|
||||
metadata: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Finish reason for completion
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
|
|
@ -431,7 +414,7 @@ pub enum FinishReason {
|
|||
|
||||
/// Token usage information
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
|
|
@ -440,18 +423,6 @@ pub struct Usage {
|
|||
pub completion_tokens_details: Option<CompletionTokensDetails>,
|
||||
}
|
||||
|
||||
impl Default for Usage {
|
||||
fn default() -> Self {
|
||||
Usage {
|
||||
prompt_tokens: 0,
|
||||
completion_tokens: 0,
|
||||
total_tokens: 0,
|
||||
prompt_tokens_details: None,
|
||||
completion_tokens_details: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Detailed breakdown of prompt tokens
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
|
|
@ -472,7 +443,7 @@ pub struct CompletionTokensDetails {
|
|||
|
||||
/// A single choice in the response
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
|
||||
pub struct Choice {
|
||||
pub index: u32,
|
||||
pub message: ResponseMessage,
|
||||
|
|
@ -480,17 +451,6 @@ pub struct Choice {
|
|||
pub logprobs: Option<Value>,
|
||||
}
|
||||
|
||||
impl Default for Choice {
|
||||
fn default() -> Self {
|
||||
Choice {
|
||||
index: 0,
|
||||
message: ResponseMessage::default(),
|
||||
finish_reason: None,
|
||||
logprobs: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// STREAMING API TYPES
|
||||
// ============================================================================
|
||||
|
|
@ -608,7 +568,6 @@ pub enum OpenAIError {
|
|||
// ============================================================================
|
||||
/// Trait Implementations
|
||||
/// ===========================================================================
|
||||
|
||||
/// Parameterized conversion for ChatCompletionsRequest
|
||||
impl TryFrom<&[u8]> for ChatCompletionsRequest {
|
||||
type Error = OpenAIStreamError;
|
||||
|
|
@ -721,7 +680,7 @@ impl ProviderRequest for ChatCompletionsRequest {
|
|||
}
|
||||
|
||||
fn metadata(&self) -> &Option<HashMap<String, Value>> {
|
||||
return &self.metadata;
|
||||
&self.metadata
|
||||
}
|
||||
|
||||
fn remove_metadata_key(&mut self, key: &str) -> bool {
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
use std::collections::HashMap;
|
||||
use crate::providers::request::{ProviderRequest, ProviderRequestError};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::skip_serializing_none;
|
||||
use crate::providers::request::{ProviderRequest, ProviderRequestError};
|
||||
use std::collections::HashMap;
|
||||
|
||||
impl TryFrom<&[u8]> for ResponsesAPIRequest {
|
||||
type Error = serde_json::Error;
|
||||
|
|
@ -172,18 +172,14 @@ pub enum MessageRole {
|
|||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum InputContent {
|
||||
/// Text input
|
||||
InputText {
|
||||
text: String,
|
||||
},
|
||||
InputText { text: String },
|
||||
/// Image input via URL
|
||||
InputImage {
|
||||
image_url: String,
|
||||
detail: Option<String>,
|
||||
},
|
||||
/// File input via URL
|
||||
InputFile {
|
||||
file_url: String,
|
||||
},
|
||||
InputFile { file_url: String },
|
||||
/// Audio input
|
||||
InputAudio {
|
||||
data: Option<String>,
|
||||
|
|
@ -222,9 +218,7 @@ pub struct TextConfig {
|
|||
pub enum TextFormat {
|
||||
Text,
|
||||
JsonObject,
|
||||
JsonSchema {
|
||||
json_schema: serde_json::Value,
|
||||
},
|
||||
JsonSchema { json_schema: serde_json::Value },
|
||||
}
|
||||
|
||||
/// Reasoning effort levels
|
||||
|
|
@ -608,9 +602,7 @@ pub enum OutputContent {
|
|||
transcript: Option<String>,
|
||||
},
|
||||
/// Refusal output
|
||||
Refusal {
|
||||
refusal: String,
|
||||
},
|
||||
Refusal { refusal: String },
|
||||
}
|
||||
|
||||
/// Annotations for output text
|
||||
|
|
@ -663,13 +655,9 @@ pub struct FileSearchResult {
|
|||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum CodeInterpreterOutput {
|
||||
/// Text output
|
||||
Text {
|
||||
text: String,
|
||||
},
|
||||
Text { text: String },
|
||||
/// Image output
|
||||
Image {
|
||||
image: String,
|
||||
},
|
||||
Image { image: String },
|
||||
}
|
||||
|
||||
/// Response usage statistics
|
||||
|
|
@ -951,9 +939,7 @@ pub enum ResponsesAPIStreamEvent {
|
|||
},
|
||||
|
||||
/// Done event (end of stream)
|
||||
Done {
|
||||
sequence_number: i32,
|
||||
},
|
||||
Done { sequence_number: i32 },
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
|
|
@ -1052,12 +1038,19 @@ impl ProviderRequest for ResponsesAPIRequest {
|
|||
MessageContent::Text(text) => text.clone(),
|
||||
MessageContent::Items(content_items) => {
|
||||
content_items.iter().fold(String::new(), |acc, content| {
|
||||
acc + " " + &match content {
|
||||
InputContent::InputText { text } => text.clone(),
|
||||
InputContent::InputImage { .. } => "[Image]".to_string(),
|
||||
InputContent::InputFile { .. } => "[File]".to_string(),
|
||||
InputContent::InputAudio { .. } => "[Audio]".to_string(),
|
||||
}
|
||||
acc + " "
|
||||
+ &match content {
|
||||
InputContent::InputText { text } => text.clone(),
|
||||
InputContent::InputImage { .. } => {
|
||||
"[Image]".to_string()
|
||||
}
|
||||
InputContent::InputFile { .. } => {
|
||||
"[File]".to_string()
|
||||
}
|
||||
InputContent::InputAudio { .. } => {
|
||||
"[Audio]".to_string()
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
};
|
||||
|
|
@ -1082,11 +1075,9 @@ impl ProviderRequest for ResponsesAPIRequest {
|
|||
match &msg.content {
|
||||
MessageContent::Text(text) => Some(text.clone()),
|
||||
MessageContent::Items(content_items) => {
|
||||
content_items.iter().find_map(|content| {
|
||||
match content {
|
||||
InputContent::InputText { text } => Some(text.clone()),
|
||||
_ => None,
|
||||
}
|
||||
content_items.iter().find_map(|content| match content {
|
||||
InputContent::InputText { text } => Some(text.clone()),
|
||||
_ => None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -1176,9 +1167,12 @@ impl ProviderRequest for ResponsesAPIRequest {
|
|||
|
||||
// Extract text from message content
|
||||
let content = match &msg.content {
|
||||
crate::apis::openai_responses::MessageContent::Text(text) => text.clone(),
|
||||
crate::apis::openai_responses::MessageContent::Text(text) => {
|
||||
text.clone()
|
||||
}
|
||||
crate::apis::openai_responses::MessageContent::Items(items) => {
|
||||
items.iter()
|
||||
items
|
||||
.iter()
|
||||
.filter_map(|c| {
|
||||
if let InputContent::InputText { text } = c {
|
||||
Some(text.clone())
|
||||
|
|
@ -1214,7 +1208,8 @@ impl ProviderRequest for ResponsesAPIRequest {
|
|||
fn set_messages(&mut self, messages: &[crate::apis::openai::Message]) {
|
||||
// For ResponsesAPI, we need to convert messages back to input format
|
||||
// Extract system messages as instructions
|
||||
let system_text = messages.iter()
|
||||
let system_text = messages
|
||||
.iter()
|
||||
.filter(|msg| msg.role == crate::apis::openai::Role::System)
|
||||
.filter_map(|msg| {
|
||||
if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
|
||||
|
|
@ -1233,23 +1228,27 @@ impl ProviderRequest for ResponsesAPIRequest {
|
|||
// Convert user/assistant messages to InputParam
|
||||
// For simplicity, we'll use the last user message as the input
|
||||
// or combine all non-system messages
|
||||
let input_messages: Vec<_> = messages.iter()
|
||||
let input_messages: Vec<_> = messages
|
||||
.iter()
|
||||
.filter(|msg| msg.role != crate::apis::openai::Role::System)
|
||||
.collect();
|
||||
|
||||
if !input_messages.is_empty() {
|
||||
// If there's only one message, use Text format
|
||||
if input_messages.len() == 1 {
|
||||
if let crate::apis::openai::MessageContent::Text(text) = &input_messages[0].content {
|
||||
if let crate::apis::openai::MessageContent::Text(text) = &input_messages[0].content
|
||||
{
|
||||
self.input = crate::apis::openai_responses::InputParam::Text(text.clone());
|
||||
}
|
||||
} else {
|
||||
// Multiple messages - combine them as text for now
|
||||
// A more sophisticated approach would use InputParam::Items
|
||||
let combined_text = input_messages.iter()
|
||||
let combined_text = input_messages
|
||||
.iter()
|
||||
.filter_map(|msg| {
|
||||
if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
|
||||
Some(format!("{}: {}",
|
||||
Some(format!(
|
||||
"{}: {}",
|
||||
match msg.role {
|
||||
crate::apis::openai::Role::User => "User",
|
||||
crate::apis::openai::Role::Assistant => "Assistant",
|
||||
|
|
@ -1274,10 +1273,10 @@ impl ProviderRequest for ResponsesAPIRequest {
|
|||
// Into<String> Implementation for SSE Formatting
|
||||
// ============================================================================
|
||||
|
||||
impl Into<String> for ResponsesAPIStreamEvent {
|
||||
fn into(self) -> String {
|
||||
let transformed_json = serde_json::to_string(&self).unwrap_or_default();
|
||||
let event_type = match &self {
|
||||
impl From<ResponsesAPIStreamEvent> for String {
|
||||
fn from(val: ResponsesAPIStreamEvent) -> Self {
|
||||
let transformed_json = serde_json::to_string(&val).unwrap_or_default();
|
||||
let event_type = match &val {
|
||||
ResponsesAPIStreamEvent::ResponseCreated { .. } => "response.created",
|
||||
ResponsesAPIStreamEvent::ResponseInProgress { .. } => "response.in_progress",
|
||||
ResponsesAPIStreamEvent::ResponseCompleted { .. } => "response.completed",
|
||||
|
|
@ -1365,10 +1364,10 @@ impl crate::providers::streaming_response::ProviderStreamResponse for ResponsesA
|
|||
|
||||
fn role(&self) -> Option<&str> {
|
||||
match self {
|
||||
ResponsesAPIStreamEvent::ResponseOutputItemDone { item, .. } => match item {
|
||||
OutputItem::Message { role, .. } => Some(role.as_str()),
|
||||
_ => None,
|
||||
},
|
||||
ResponsesAPIStreamEvent::ResponseOutputItemDone {
|
||||
item: OutputItem::Message { role, .. },
|
||||
..
|
||||
} => Some(role.as_str()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -34,10 +34,7 @@ where
|
|||
}
|
||||
|
||||
pub fn decode_frame(&mut self) -> Option<DecodedFrame> {
|
||||
match self.decoder.decode_frame(&mut self.buffer) {
|
||||
Ok(frame) => Some(frame),
|
||||
Err(_e) => None, // Fatal decode error
|
||||
}
|
||||
self.decoder.decode_frame(&mut self.buffer).ok()
|
||||
}
|
||||
|
||||
pub fn buffer_mut(&mut self) -> &mut B {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait};
|
||||
use crate::apis::anthropic::MessagesStreamEvent;
|
||||
use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait};
|
||||
use crate::providers::streaming_response::ProviderStreamResponseType;
|
||||
use std::collections::HashSet;
|
||||
|
||||
|
|
@ -31,6 +31,12 @@ pub struct AnthropicMessagesStreamBuffer {
|
|||
model: Option<String>,
|
||||
}
|
||||
|
||||
impl Default for AnthropicMessagesStreamBuffer {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl AnthropicMessagesStreamBuffer {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
|
|
@ -154,7 +160,8 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
|
|||
// Inject message_start if needed
|
||||
if !self.message_started {
|
||||
let model = self.model.as_deref().unwrap_or("unknown");
|
||||
let message_start = AnthropicMessagesStreamBuffer::create_message_start_event(model);
|
||||
let message_start =
|
||||
AnthropicMessagesStreamBuffer::create_message_start_event(model);
|
||||
self.buffered_events.push(message_start);
|
||||
self.message_started = true;
|
||||
}
|
||||
|
|
@ -169,7 +176,8 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
|
|||
// Inject message_start if needed
|
||||
if !self.message_started {
|
||||
let model = self.model.as_deref().unwrap_or("unknown");
|
||||
let message_start = AnthropicMessagesStreamBuffer::create_message_start_event(model);
|
||||
let message_start =
|
||||
AnthropicMessagesStreamBuffer::create_message_start_event(model);
|
||||
self.buffered_events.push(message_start);
|
||||
self.message_started = true;
|
||||
}
|
||||
|
|
@ -177,7 +185,8 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
|
|||
// Check if ContentBlockStart was sent for this index
|
||||
if !self.has_content_block_start_been_sent(index) {
|
||||
// Inject ContentBlockStart before delta
|
||||
let content_block_start = AnthropicMessagesStreamBuffer::create_content_block_start_event();
|
||||
let content_block_start =
|
||||
AnthropicMessagesStreamBuffer::create_content_block_start_event();
|
||||
self.buffered_events.push(content_block_start);
|
||||
self.set_content_block_start_sent(index);
|
||||
self.needs_content_block_stop = true;
|
||||
|
|
@ -189,7 +198,8 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
|
|||
MessagesStreamEvent::MessageDelta { usage, .. } => {
|
||||
// Inject ContentBlockStop before message_delta
|
||||
if self.needs_content_block_stop {
|
||||
let content_block_stop = AnthropicMessagesStreamBuffer::create_content_block_stop_event();
|
||||
let content_block_stop =
|
||||
AnthropicMessagesStreamBuffer::create_content_block_stop_event();
|
||||
self.buffered_events.push(content_block_stop);
|
||||
self.needs_content_block_stop = false;
|
||||
}
|
||||
|
|
@ -199,10 +209,10 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
|
|||
if let Some(last_event) = self.buffered_events.last_mut() {
|
||||
if let Some(ProviderStreamResponseType::MessagesStreamEvent(
|
||||
MessagesStreamEvent::MessageDelta {
|
||||
usage: last_usage,
|
||||
..
|
||||
}
|
||||
)) = &mut last_event.provider_stream_response {
|
||||
usage: last_usage, ..
|
||||
},
|
||||
)) = &mut last_event.provider_stream_response
|
||||
{
|
||||
// Merge: take stop_reason from first, usage from second (if non-zero)
|
||||
if usage.input_tokens > 0 || usage.output_tokens > 0 {
|
||||
*last_usage = usage.clone();
|
||||
|
|
@ -243,7 +253,7 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
|
|||
}
|
||||
}
|
||||
|
||||
fn into_bytes(&mut self) -> Vec<u8> {
|
||||
fn to_bytes(&mut self) -> Vec<u8> {
|
||||
// Convert all accumulated events to bytes and clear buffer
|
||||
// NOTE: We do NOT inject ContentBlockStop here because it's injected when we see MessageDelta
|
||||
// or MessageStop. Injecting it here causes premature ContentBlockStop in the middle of streaming.
|
||||
|
|
@ -276,10 +286,10 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
||||
use crate::apis::anthropic::AnthropicApi;
|
||||
use crate::apis::openai::OpenAIApi;
|
||||
use crate::apis::streaming_shapes::sse::SseStreamIter;
|
||||
use crate::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
||||
|
||||
#[test]
|
||||
fn test_openai_to_anthropic_complete_transformation() {
|
||||
|
|
@ -308,11 +318,12 @@ data: [DONE]"#;
|
|||
let mut buffer = AnthropicMessagesStreamBuffer::new();
|
||||
|
||||
for raw_event in stream_iter {
|
||||
let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
|
||||
let transformed_event =
|
||||
SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
|
||||
buffer.add_transformed_event(transformed_event);
|
||||
}
|
||||
|
||||
let output_bytes = buffer.into_bytes();
|
||||
let output_bytes = buffer.to_bytes();
|
||||
let output = String::from_utf8_lossy(&output_bytes);
|
||||
|
||||
println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):");
|
||||
|
|
@ -321,25 +332,54 @@ data: [DONE]"#;
|
|||
|
||||
// Assertions
|
||||
assert!(!output_bytes.is_empty(), "Should have output");
|
||||
assert!(output.contains("event: message_start"), "Should have message_start");
|
||||
assert!(output.contains("event: content_block_start"), "Should have content_block_start (injected)");
|
||||
assert!(
|
||||
output.contains("event: message_start"),
|
||||
"Should have message_start"
|
||||
);
|
||||
assert!(
|
||||
output.contains("event: content_block_start"),
|
||||
"Should have content_block_start (injected)"
|
||||
);
|
||||
|
||||
let delta_count = output.matches("event: content_block_delta").count();
|
||||
assert_eq!(delta_count, 2, "Should have exactly 2 content_block_delta events");
|
||||
assert_eq!(
|
||||
delta_count, 2,
|
||||
"Should have exactly 2 content_block_delta events"
|
||||
);
|
||||
|
||||
// Verify both pieces of content are present
|
||||
assert!(output.contains("\"text\":\"Hello\""), "Should have first content delta 'Hello'");
|
||||
assert!(output.contains("\"text\":\" world\""), "Should have second content delta ' world'");
|
||||
assert!(
|
||||
output.contains("\"text\":\"Hello\""),
|
||||
"Should have first content delta 'Hello'"
|
||||
);
|
||||
assert!(
|
||||
output.contains("\"text\":\" world\""),
|
||||
"Should have second content delta ' world'"
|
||||
);
|
||||
|
||||
assert!(output.contains("event: content_block_stop"), "Should have content_block_stop (injected)");
|
||||
assert!(output.contains("event: message_delta"), "Should have message_delta");
|
||||
assert!(output.contains("event: message_stop"), "Should have message_stop");
|
||||
assert!(
|
||||
output.contains("event: content_block_stop"),
|
||||
"Should have content_block_stop (injected)"
|
||||
);
|
||||
assert!(
|
||||
output.contains("event: message_delta"),
|
||||
"Should have message_delta"
|
||||
);
|
||||
assert!(
|
||||
output.contains("event: message_stop"),
|
||||
"Should have message_stop"
|
||||
);
|
||||
|
||||
println!("\nVALIDATION SUMMARY:");
|
||||
println!("{}", "-".repeat(80));
|
||||
println!("✓ Complete transformation: OpenAI ChatCompletions → Anthropic Messages API");
|
||||
println!("✓ Injected lifecycle events: message_start, content_block_start, content_block_stop");
|
||||
println!("✓ Content deltas: {} events (BOTH 'Hello' and ' world' preserved!)", delta_count);
|
||||
println!(
|
||||
"✓ Injected lifecycle events: message_start, content_block_start, content_block_stop"
|
||||
);
|
||||
println!(
|
||||
"✓ Content deltas: {} events (BOTH 'Hello' and ' world' preserved!)",
|
||||
delta_count
|
||||
);
|
||||
println!("✓ Complete stream with message_stop");
|
||||
println!("✓ Proper Anthropic protocol sequencing\n");
|
||||
}
|
||||
|
|
@ -369,11 +409,12 @@ data: {"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890
|
|||
let mut buffer = AnthropicMessagesStreamBuffer::new();
|
||||
|
||||
for raw_event in stream_iter {
|
||||
let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
|
||||
let transformed_event =
|
||||
SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
|
||||
buffer.add_transformed_event(transformed_event);
|
||||
}
|
||||
|
||||
let output_bytes = buffer.into_bytes();
|
||||
let output_bytes = buffer.to_bytes();
|
||||
let output = String::from_utf8_lossy(&output_bytes);
|
||||
|
||||
println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):");
|
||||
|
|
@ -382,31 +423,61 @@ data: {"id":"chatcmpl-456","object":"chat.completion.chunk","created":1234567890
|
|||
|
||||
// Assertions
|
||||
assert!(!output_bytes.is_empty(), "Should have output");
|
||||
assert!(output.contains("event: message_start"), "Should have message_start");
|
||||
assert!(output.contains("event: content_block_start"), "Should have content_block_start (injected)");
|
||||
assert!(
|
||||
output.contains("event: message_start"),
|
||||
"Should have message_start"
|
||||
);
|
||||
assert!(
|
||||
output.contains("event: content_block_start"),
|
||||
"Should have content_block_start (injected)"
|
||||
);
|
||||
|
||||
let delta_count = output.matches("event: content_block_delta").count();
|
||||
assert_eq!(delta_count, 3, "Should have exactly 3 content_block_delta events");
|
||||
assert_eq!(
|
||||
delta_count, 3,
|
||||
"Should have exactly 3 content_block_delta events"
|
||||
);
|
||||
|
||||
// Verify all three pieces of content are present
|
||||
assert!(output.contains("\"text\":\"The weather\""), "Should have first content delta");
|
||||
assert!(output.contains("\"text\":\" in San Francisco\""), "Should have second content delta");
|
||||
assert!(output.contains("\"text\":\" is\""), "Should have third content delta");
|
||||
assert!(
|
||||
output.contains("\"text\":\"The weather\""),
|
||||
"Should have first content delta"
|
||||
);
|
||||
assert!(
|
||||
output.contains("\"text\":\" in San Francisco\""),
|
||||
"Should have second content delta"
|
||||
);
|
||||
assert!(
|
||||
output.contains("\"text\":\" is\""),
|
||||
"Should have third content delta"
|
||||
);
|
||||
|
||||
// For partial streams (no finish_reason, no [DONE]), we do NOT inject content_block_stop
|
||||
// because the stream may continue. This is correct behavior - only inject lifecycle events
|
||||
// when we have explicit signals from upstream (finish_reason, [DONE], etc.)
|
||||
assert!(!output.contains("event: content_block_stop"), "Should NOT have content_block_stop for partial stream");
|
||||
assert!(
|
||||
!output.contains("event: content_block_stop"),
|
||||
"Should NOT have content_block_stop for partial stream"
|
||||
);
|
||||
|
||||
// Should NOT have completion events
|
||||
assert!(!output.contains("event: message_delta"), "Should NOT have message_delta");
|
||||
assert!(!output.contains("event: message_stop"), "Should NOT have message_stop");
|
||||
assert!(
|
||||
!output.contains("event: message_delta"),
|
||||
"Should NOT have message_delta"
|
||||
);
|
||||
assert!(
|
||||
!output.contains("event: message_stop"),
|
||||
"Should NOT have message_stop"
|
||||
);
|
||||
|
||||
println!("\nVALIDATION SUMMARY:");
|
||||
println!("{}", "-".repeat(80));
|
||||
println!("✓ Partial transformation: OpenAI → Anthropic (stream interrupted)");
|
||||
println!("✓ Injected: message_start, content_block_start at beginning");
|
||||
println!("✓ Incremental deltas: {} events (ALL content preserved!)", delta_count);
|
||||
println!(
|
||||
"✓ Incremental deltas: {} events (ALL content preserved!)",
|
||||
delta_count
|
||||
);
|
||||
println!("✓ NO completion events (partial stream, no [DONE])");
|
||||
println!("✓ Buffer maintains Anthropic protocol for active streams\n");
|
||||
}
|
||||
|
|
@ -452,11 +523,12 @@ data: [DONE]"#;
|
|||
let mut buffer = AnthropicMessagesStreamBuffer::new();
|
||||
|
||||
for raw_event in stream_iter {
|
||||
let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
|
||||
let transformed_event =
|
||||
SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
|
||||
buffer.add_transformed_event(transformed_event);
|
||||
}
|
||||
|
||||
let output_bytes = buffer.into_bytes();
|
||||
let output_bytes = buffer.to_bytes();
|
||||
let output = String::from_utf8_lossy(&output_bytes);
|
||||
|
||||
println!("\nTRANSFORMED OUTPUT (Anthropic Messages API):");
|
||||
|
|
@ -467,32 +539,71 @@ data: [DONE]"#;
|
|||
assert!(!output_bytes.is_empty(), "Should have output");
|
||||
|
||||
// Should have lifecycle events (injected by buffer)
|
||||
assert!(output.contains("event: message_start"), "Should have message_start (injected)");
|
||||
assert!(output.contains("event: content_block_start"), "Should have content_block_start");
|
||||
assert!(output.contains("event: content_block_stop"), "Should have content_block_stop (injected)");
|
||||
assert!(output.contains("event: message_delta"), "Should have message_delta");
|
||||
assert!(output.contains("event: message_stop"), "Should have message_stop");
|
||||
assert!(
|
||||
output.contains("event: message_start"),
|
||||
"Should have message_start (injected)"
|
||||
);
|
||||
assert!(
|
||||
output.contains("event: content_block_start"),
|
||||
"Should have content_block_start"
|
||||
);
|
||||
assert!(
|
||||
output.contains("event: content_block_stop"),
|
||||
"Should have content_block_stop (injected)"
|
||||
);
|
||||
assert!(
|
||||
output.contains("event: message_delta"),
|
||||
"Should have message_delta"
|
||||
);
|
||||
assert!(
|
||||
output.contains("event: message_stop"),
|
||||
"Should have message_stop"
|
||||
);
|
||||
|
||||
// Should have tool_use content block
|
||||
assert!(output.contains("\"type\":\"tool_use\""), "Should have tool_use type");
|
||||
assert!(output.contains("\"name\":\"get_weather\""), "Should have correct function name");
|
||||
assert!(output.contains("\"id\":\"call_2Uzw0AEZQeOex2CP2TKjcLKc\""), "Should have correct tool call ID");
|
||||
assert!(
|
||||
output.contains("\"type\":\"tool_use\""),
|
||||
"Should have tool_use type"
|
||||
);
|
||||
assert!(
|
||||
output.contains("\"name\":\"get_weather\""),
|
||||
"Should have correct function name"
|
||||
);
|
||||
assert!(
|
||||
output.contains("\"id\":\"call_2Uzw0AEZQeOex2CP2TKjcLKc\""),
|
||||
"Should have correct tool call ID"
|
||||
);
|
||||
|
||||
// Count input_json_delta events - should match the number of argument chunks
|
||||
let delta_count = output.matches("event: content_block_delta").count();
|
||||
assert!(delta_count >= 8, "Should have at least 8 input_json_delta events");
|
||||
assert!(
|
||||
delta_count >= 8,
|
||||
"Should have at least 8 input_json_delta events"
|
||||
);
|
||||
|
||||
// Verify argument deltas are present
|
||||
assert!(output.contains("\"type\":\"input_json_delta\""), "Should have input_json_delta type");
|
||||
assert!(output.contains("\"partial_json\":"), "Should have partial_json field");
|
||||
assert!(
|
||||
output.contains("\"type\":\"input_json_delta\""),
|
||||
"Should have input_json_delta type"
|
||||
);
|
||||
assert!(
|
||||
output.contains("\"partial_json\":"),
|
||||
"Should have partial_json field"
|
||||
);
|
||||
|
||||
// Verify the accumulated arguments contain the location
|
||||
assert!(output.contains("San"), "Arguments should contain 'San'");
|
||||
assert!(output.contains("Francisco"), "Arguments should contain 'Francisco'");
|
||||
assert!(
|
||||
output.contains("Francisco"),
|
||||
"Arguments should contain 'Francisco'"
|
||||
);
|
||||
assert!(output.contains("CA"), "Arguments should contain 'CA'");
|
||||
|
||||
// Verify stop reason is tool_use
|
||||
assert!(output.contains("\"stop_reason\":\"tool_use\""), "Should have stop_reason as tool_use");
|
||||
assert!(
|
||||
output.contains("\"stop_reason\":\"tool_use\""),
|
||||
"Should have stop_reason as tool_use"
|
||||
);
|
||||
|
||||
println!("\nVALIDATION SUMMARY:");
|
||||
println!("{}", "-".repeat(80));
|
||||
|
|
|
|||
|
|
@ -6,6 +6,12 @@ pub struct OpenAIChatCompletionsStreamBuffer {
|
|||
buffered_events: Vec<SseEvent>,
|
||||
}
|
||||
|
||||
impl Default for OpenAIChatCompletionsStreamBuffer {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl OpenAIChatCompletionsStreamBuffer {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
|
|
@ -26,7 +32,7 @@ impl SseStreamBufferTrait for OpenAIChatCompletionsStreamBuffer {
|
|||
self.buffered_events.push(event);
|
||||
}
|
||||
|
||||
fn into_bytes(&mut self) -> Vec<u8> {
|
||||
fn to_bytes(&mut self) -> Vec<u8> {
|
||||
// No finalization needed for OpenAI Chat Completions
|
||||
// The [DONE] marker is already handled by the transformation layer
|
||||
let mut buffer = Vec::new();
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
pub mod sse;
|
||||
pub mod sse_chunk_processor;
|
||||
pub mod amazon_bedrock_binary_frame;
|
||||
pub mod anthropic_streaming_buffer;
|
||||
pub mod chat_completions_streaming_buffer;
|
||||
pub mod passthrough_streaming_buffer;
|
||||
pub mod responses_api_streaming_buffer;
|
||||
pub mod sse;
|
||||
pub mod sse_chunk_processor;
|
||||
|
|
|
|||
|
|
@ -6,6 +6,12 @@ pub struct PassthroughStreamBuffer {
|
|||
buffered_events: Vec<SseEvent>,
|
||||
}
|
||||
|
||||
impl Default for PassthroughStreamBuffer {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl PassthroughStreamBuffer {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
|
|
@ -30,7 +36,7 @@ impl SseStreamBufferTrait for PassthroughStreamBuffer {
|
|||
self.buffered_events.push(event);
|
||||
}
|
||||
|
||||
fn into_bytes(&mut self) -> Vec<u8> {
|
||||
fn to_bytes(&mut self) -> Vec<u8> {
|
||||
// No finalization needed for passthrough - just convert accumulated events to bytes
|
||||
let mut buffer = Vec::new();
|
||||
for event in self.buffered_events.drain(..) {
|
||||
|
|
@ -44,7 +50,7 @@ impl SseStreamBufferTrait for PassthroughStreamBuffer {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::apis::streaming_shapes::passthrough_streaming_buffer::PassthroughStreamBuffer;
|
||||
use crate::apis::streaming_shapes::sse::{SseStreamIter, SseStreamBufferTrait};
|
||||
use crate::apis::streaming_shapes::sse::{SseStreamBufferTrait, SseStreamIter};
|
||||
|
||||
#[test]
|
||||
fn test_chat_completions_passthrough_buffer() {
|
||||
|
|
@ -73,7 +79,7 @@ mod tests {
|
|||
buffer.add_transformed_event(event);
|
||||
}
|
||||
|
||||
let output_bytes = buffer.into_bytes();
|
||||
let output_bytes = buffer.to_bytes();
|
||||
let output = String::from_utf8_lossy(&output_bytes);
|
||||
|
||||
println!("\nTRANSFORMED OUTPUT (ChatCompletions - Passthrough):");
|
||||
|
|
@ -84,7 +90,11 @@ mod tests {
|
|||
assert!(!output_bytes.is_empty());
|
||||
assert!(output.contains("chatcmpl-123"));
|
||||
assert!(output.contains("[DONE]"));
|
||||
assert_eq!(raw_input.trim(), output.trim(), "Passthrough should preserve input");
|
||||
assert_eq!(
|
||||
raw_input.trim(),
|
||||
output.trim(),
|
||||
"Passthrough should preserve input"
|
||||
);
|
||||
|
||||
println!("\nVALIDATION SUMMARY:");
|
||||
println!("{}", "-".repeat(80));
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
use std::collections::HashMap;
|
||||
use log::debug;
|
||||
use crate::apis::openai_responses::{
|
||||
ResponsesAPIStreamEvent, ResponsesAPIResponse, OutputItem, OutputItemStatus,
|
||||
ResponseStatus, TextConfig, TextFormat, Reasoning,
|
||||
OutputItem, OutputItemStatus, Reasoning, ResponseStatus, ResponsesAPIResponse,
|
||||
ResponsesAPIStreamEvent, TextConfig, TextFormat,
|
||||
};
|
||||
use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait};
|
||||
use log::debug;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Helper to convert ResponseAPIStreamEvent to SseEvent
|
||||
fn event_to_sse(event: ResponsesAPIStreamEvent) -> SseEvent {
|
||||
|
|
@ -16,10 +16,17 @@ fn event_to_sse(event: ResponsesAPIStreamEvent) -> SseEvent {
|
|||
ResponsesAPIStreamEvent::ResponseOutputItemDone { .. } => "response.output_item.done",
|
||||
ResponsesAPIStreamEvent::ResponseOutputTextDelta { .. } => "response.output_text.delta",
|
||||
ResponsesAPIStreamEvent::ResponseOutputTextDone { .. } => "response.output_text.done",
|
||||
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { .. } => "response.function_call_arguments.delta",
|
||||
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDone { .. } => "response.function_call_arguments.done",
|
||||
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { .. } => {
|
||||
"response.function_call_arguments.delta"
|
||||
}
|
||||
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDone { .. } => {
|
||||
"response.function_call_arguments.done"
|
||||
}
|
||||
unknown => {
|
||||
debug!("Unknown ResponsesAPIStreamEvent type encountered: {:?}", unknown);
|
||||
debug!(
|
||||
"Unknown ResponsesAPIStreamEvent type encountered: {:?}",
|
||||
unknown
|
||||
);
|
||||
"unknown"
|
||||
}
|
||||
};
|
||||
|
|
@ -85,6 +92,12 @@ pub struct ResponsesAPIStreamBuffer {
|
|||
buffered_events: Vec<SseEvent>,
|
||||
}
|
||||
|
||||
impl Default for ResponsesAPIStreamBuffer {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ResponsesAPIStreamBuffer {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
|
|
@ -112,7 +125,11 @@ impl ResponsesAPIStreamBuffer {
|
|||
}
|
||||
|
||||
fn generate_item_id(prefix: &str) -> String {
|
||||
format!("{}_{}", prefix, uuid::Uuid::new_v4().to_string().replace("-", ""))
|
||||
format!(
|
||||
"{}_{}",
|
||||
prefix,
|
||||
uuid::Uuid::new_v4().to_string().replace("-", "")
|
||||
)
|
||||
}
|
||||
|
||||
fn get_or_create_item_id(&mut self, output_index: i32, prefix: &str) -> String {
|
||||
|
|
@ -160,7 +177,13 @@ impl ResponsesAPIStreamBuffer {
|
|||
}
|
||||
|
||||
/// Create output_item.added event for tool call
|
||||
fn create_tool_call_added_event(&mut self, output_index: i32, item_id: &str, call_id: &str, name: &str) -> SseEvent {
|
||||
fn create_tool_call_added_event(
|
||||
&mut self,
|
||||
output_index: i32,
|
||||
item_id: &str,
|
||||
call_id: &str,
|
||||
name: &str,
|
||||
) -> SseEvent {
|
||||
let event = ResponsesAPIStreamEvent::ResponseOutputItemAdded {
|
||||
output_index,
|
||||
item: OutputItem::FunctionCall {
|
||||
|
|
@ -237,9 +260,15 @@ impl ResponsesAPIStreamBuffer {
|
|||
// Emit done events for all accumulated content
|
||||
|
||||
// Text content done events
|
||||
let text_items: Vec<_> = self.text_content.iter().map(|(id, content)| (id.clone(), content.clone())).collect();
|
||||
let text_items: Vec<_> = self
|
||||
.text_content
|
||||
.iter()
|
||||
.map(|(id, content)| (id.clone(), content.clone()))
|
||||
.collect();
|
||||
for (item_id, content) in text_items {
|
||||
let output_index = self.output_items_added.iter()
|
||||
let output_index = self
|
||||
.output_items_added
|
||||
.iter()
|
||||
.find(|(_, id)| **id == item_id)
|
||||
.map(|(idx, _)| *idx)
|
||||
.unwrap_or(0);
|
||||
|
|
@ -270,9 +299,15 @@ impl ResponsesAPIStreamBuffer {
|
|||
}
|
||||
|
||||
// Function call done events
|
||||
let func_items: Vec<_> = self.function_arguments.iter().map(|(id, args)| (id.clone(), args.clone())).collect();
|
||||
let func_items: Vec<_> = self
|
||||
.function_arguments
|
||||
.iter()
|
||||
.map(|(id, args)| (id.clone(), args.clone()))
|
||||
.collect();
|
||||
for (item_id, arguments) in func_items {
|
||||
let output_index = self.output_items_added.iter()
|
||||
let output_index = self
|
||||
.output_items_added
|
||||
.iter()
|
||||
.find(|(_, id)| **id == item_id)
|
||||
.map(|(idx, _)| *idx)
|
||||
.unwrap_or(0);
|
||||
|
|
@ -286,9 +321,16 @@ impl ResponsesAPIStreamBuffer {
|
|||
};
|
||||
events.push(event_to_sse(args_done_event));
|
||||
|
||||
let (call_id, name) = self.tool_call_metadata.get(&output_index)
|
||||
let (call_id, name) = self
|
||||
.tool_call_metadata
|
||||
.get(&output_index)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string()));
|
||||
.unwrap_or_else(|| {
|
||||
(
|
||||
format!("call_{}", uuid::Uuid::new_v4()),
|
||||
"unknown".to_string(),
|
||||
)
|
||||
});
|
||||
|
||||
let seq2 = self.next_sequence_number();
|
||||
let item_done_event = ResponsesAPIStreamEvent::ResponseOutputItemDone {
|
||||
|
|
@ -315,9 +357,16 @@ impl ResponsesAPIStreamBuffer {
|
|||
if let Some(item_id) = self.output_items_added.get(&output_index) {
|
||||
// Check if this is a function call
|
||||
if let Some(arguments) = self.function_arguments.get(item_id) {
|
||||
let (call_id, name) = self.tool_call_metadata.get(&output_index)
|
||||
let (call_id, name) = self
|
||||
.tool_call_metadata
|
||||
.get(&output_index)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string()));
|
||||
.unwrap_or_else(|| {
|
||||
(
|
||||
format!("call_{}", uuid::Uuid::new_v4()),
|
||||
"unknown".to_string(),
|
||||
)
|
||||
});
|
||||
|
||||
output_items.push(OutputItem::FunctionCall {
|
||||
id: item_id.clone(),
|
||||
|
|
@ -397,9 +446,9 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
|
|||
let mut events = Vec::new();
|
||||
|
||||
// Capture upstream metadata from ResponseCreated or ResponseInProgress if present
|
||||
match stream_event {
|
||||
ResponsesAPIStreamEvent::ResponseCreated { response, .. } |
|
||||
ResponsesAPIStreamEvent::ResponseInProgress { response, .. } => {
|
||||
match stream_event.as_ref() {
|
||||
ResponsesAPIStreamEvent::ResponseCreated { response, .. }
|
||||
| ResponsesAPIStreamEvent::ResponseInProgress { response, .. } => {
|
||||
if self.upstream_response_metadata.is_none() {
|
||||
// Store the full upstream response as our metadata template
|
||||
self.upstream_response_metadata = Some(response.clone());
|
||||
|
|
@ -418,11 +467,16 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
|
|||
if !self.created_emitted {
|
||||
// Initialize metadata from first event if needed
|
||||
if self.response_id.is_none() {
|
||||
self.response_id = Some(format!("resp_{}", uuid::Uuid::new_v4().to_string().replace("-", "")));
|
||||
self.created_at = Some(std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs() as i64);
|
||||
self.response_id = Some(format!(
|
||||
"resp_{}",
|
||||
uuid::Uuid::new_v4().to_string().replace("-", "")
|
||||
));
|
||||
self.created_at = Some(
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs() as i64,
|
||||
);
|
||||
self.model = Some("unknown".to_string()); // Will be set by caller if available
|
||||
}
|
||||
|
||||
|
|
@ -436,58 +490,95 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
|
|||
}
|
||||
|
||||
// Process the delta event
|
||||
match stream_event {
|
||||
ResponsesAPIStreamEvent::ResponseOutputTextDelta { output_index, delta, .. } => {
|
||||
match stream_event.as_ref() {
|
||||
ResponsesAPIStreamEvent::ResponseOutputTextDelta {
|
||||
output_index,
|
||||
delta,
|
||||
..
|
||||
} => {
|
||||
let item_id = self.get_or_create_item_id(*output_index, "msg");
|
||||
|
||||
// Emit output_item.added if this is the first time we see this output index
|
||||
if !self.output_items_added.contains_key(output_index) {
|
||||
self.output_items_added.insert(*output_index, item_id.clone());
|
||||
self.output_items_added
|
||||
.insert(*output_index, item_id.clone());
|
||||
events.push(self.create_output_item_added_event(*output_index, &item_id));
|
||||
}
|
||||
|
||||
// Accumulate text content
|
||||
self.text_content.entry(item_id.clone())
|
||||
self.text_content
|
||||
.entry(item_id.clone())
|
||||
.and_modify(|content| content.push_str(delta))
|
||||
.or_insert_with(|| delta.clone());
|
||||
|
||||
// Emit text delta with filled-in item_id and sequence_number
|
||||
let mut delta_event = stream_event.clone();
|
||||
if let ResponsesAPIStreamEvent::ResponseOutputTextDelta { item_id: ref mut id, sequence_number: ref mut seq, .. } = delta_event {
|
||||
let mut delta_event = stream_event.as_ref().clone();
|
||||
if let ResponsesAPIStreamEvent::ResponseOutputTextDelta {
|
||||
item_id: ref mut id,
|
||||
sequence_number: ref mut seq,
|
||||
..
|
||||
} = &mut delta_event
|
||||
{
|
||||
*id = item_id;
|
||||
*seq = self.next_sequence_number();
|
||||
}
|
||||
events.push(event_to_sse(delta_event));
|
||||
}
|
||||
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { output_index, delta, call_id, name, .. } => {
|
||||
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta {
|
||||
output_index,
|
||||
delta,
|
||||
call_id,
|
||||
name,
|
||||
..
|
||||
} => {
|
||||
let item_id = self.get_or_create_item_id(*output_index, "fc");
|
||||
|
||||
// Store metadata if provided (from initial tool call event)
|
||||
if let (Some(cid), Some(n)) = (call_id, name) {
|
||||
self.tool_call_metadata.insert(*output_index, (cid.clone(), n.clone()));
|
||||
self.tool_call_metadata
|
||||
.insert(*output_index, (cid.clone(), n.clone()));
|
||||
}
|
||||
|
||||
// Emit output_item.added if this is the first time we see this tool call
|
||||
if !self.output_items_added.contains_key(output_index) {
|
||||
self.output_items_added.insert(*output_index, item_id.clone());
|
||||
self.output_items_added
|
||||
.insert(*output_index, item_id.clone());
|
||||
|
||||
// For tool calls, we need call_id and name from metadata
|
||||
// These should now be populated from the event itself
|
||||
let (call_id, name) = self.tool_call_metadata.get(output_index)
|
||||
let (call_id, name) = self
|
||||
.tool_call_metadata
|
||||
.get(output_index)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string()));
|
||||
.unwrap_or_else(|| {
|
||||
(
|
||||
format!("call_{}", uuid::Uuid::new_v4()),
|
||||
"unknown".to_string(),
|
||||
)
|
||||
});
|
||||
|
||||
events.push(self.create_tool_call_added_event(*output_index, &item_id, &call_id, &name));
|
||||
events.push(self.create_tool_call_added_event(
|
||||
*output_index,
|
||||
&item_id,
|
||||
&call_id,
|
||||
&name,
|
||||
));
|
||||
}
|
||||
|
||||
// Accumulate function arguments
|
||||
self.function_arguments.entry(item_id.clone())
|
||||
self.function_arguments
|
||||
.entry(item_id.clone())
|
||||
.and_modify(|args| args.push_str(delta))
|
||||
.or_insert_with(|| delta.clone());
|
||||
|
||||
// Emit function call arguments delta with filled-in item_id and sequence_number
|
||||
let mut delta_event = stream_event.clone();
|
||||
if let ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { item_id: ref mut id, sequence_number: ref mut seq, .. } = delta_event {
|
||||
let mut delta_event = stream_event.as_ref().clone();
|
||||
if let ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta {
|
||||
item_id: ref mut id,
|
||||
sequence_number: ref mut seq,
|
||||
..
|
||||
} = &mut delta_event
|
||||
{
|
||||
*id = item_id;
|
||||
*seq = self.next_sequence_number();
|
||||
}
|
||||
|
|
@ -495,7 +586,7 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
|
|||
}
|
||||
_ => {
|
||||
// For other event types, just pass through with sequence number
|
||||
let other_event = stream_event.clone();
|
||||
let other_event = stream_event.as_ref().clone();
|
||||
// TODO: Add sequence number to other event types if needed
|
||||
events.push(event_to_sse(other_event));
|
||||
}
|
||||
|
|
@ -505,8 +596,7 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
|
|||
self.buffered_events.extend(events);
|
||||
}
|
||||
|
||||
|
||||
fn into_bytes(&mut self) -> Vec<u8> {
|
||||
fn to_bytes(&mut self) -> Vec<u8> {
|
||||
// For Responses API, we need special handling:
|
||||
// - Most events are already in buffered_events from add_transformed_event
|
||||
// - We should NOT finalize here - finalization happens when we detect [DONE] or end of stream
|
||||
|
|
@ -525,9 +615,9 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
||||
use crate::apis::openai::OpenAIApi;
|
||||
use crate::apis::streaming_shapes::sse::SseStreamIter;
|
||||
use crate::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
||||
|
||||
#[test]
|
||||
fn test_chat_completions_to_responses_api_transformation() {
|
||||
|
|
@ -557,11 +647,12 @@ mod tests {
|
|||
|
||||
for raw_event in stream_iter {
|
||||
// Transform the event using the client/upstream APIs
|
||||
let transformed_event = SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
|
||||
let transformed_event =
|
||||
SseEvent::try_from((raw_event, &client_api, &upstream_api)).unwrap();
|
||||
buffer.add_transformed_event(transformed_event);
|
||||
}
|
||||
|
||||
let output_bytes = buffer.into_bytes();
|
||||
let output_bytes = buffer.to_bytes();
|
||||
let output = String::from_utf8_lossy(&output_bytes);
|
||||
|
||||
println!("\nTRANSFORMED OUTPUT (ResponsesAPI):");
|
||||
|
|
@ -570,13 +661,34 @@ mod tests {
|
|||
|
||||
// Assertions
|
||||
assert!(!output_bytes.is_empty(), "Should have output");
|
||||
assert!(output.contains("response.created"), "Should have response.created");
|
||||
assert!(output.contains("response.in_progress"), "Should have response.in_progress");
|
||||
assert!(output.contains("response.output_item.added"), "Should have output_item.added");
|
||||
assert!(output.contains("response.output_text.delta"), "Should have text deltas");
|
||||
assert!(output.contains("response.output_text.done"), "Should have text.done");
|
||||
assert!(output.contains("response.output_item.done"), "Should have output_item.done");
|
||||
assert!(output.contains("response.completed"), "Should have response.completed");
|
||||
assert!(
|
||||
output.contains("response.created"),
|
||||
"Should have response.created"
|
||||
);
|
||||
assert!(
|
||||
output.contains("response.in_progress"),
|
||||
"Should have response.in_progress"
|
||||
);
|
||||
assert!(
|
||||
output.contains("response.output_item.added"),
|
||||
"Should have output_item.added"
|
||||
);
|
||||
assert!(
|
||||
output.contains("response.output_text.delta"),
|
||||
"Should have text deltas"
|
||||
);
|
||||
assert!(
|
||||
output.contains("response.output_text.done"),
|
||||
"Should have text.done"
|
||||
);
|
||||
assert!(
|
||||
output.contains("response.output_item.done"),
|
||||
"Should have output_item.done"
|
||||
);
|
||||
assert!(
|
||||
output.contains("response.completed"),
|
||||
"Should have response.completed"
|
||||
);
|
||||
|
||||
println!("\nVALIDATION SUMMARY:");
|
||||
println!("{}", "-".repeat(80));
|
||||
|
|
@ -616,7 +728,7 @@ mod tests {
|
|||
buffer.add_transformed_event(transformed);
|
||||
}
|
||||
|
||||
let output_bytes = buffer.into_bytes();
|
||||
let output_bytes = buffer.to_bytes();
|
||||
let output = String::from_utf8_lossy(&output_bytes);
|
||||
|
||||
println!("\nTRANSFORMED OUTPUT (ResponsesAPI):");
|
||||
|
|
@ -624,24 +736,55 @@ mod tests {
|
|||
println!("{}", output);
|
||||
|
||||
// Assertions
|
||||
assert!(output.contains("response.created"), "Should have response.created");
|
||||
assert!(output.contains("response.in_progress"), "Should have response.in_progress");
|
||||
assert!(output.contains("response.output_item.added"), "Should have output_item.added");
|
||||
assert!(output.contains("\"type\":\"function_call\""), "Should be function_call type");
|
||||
assert!(output.contains("\"name\":\"get_weather\""), "Should have function name");
|
||||
assert!(output.contains("\"call_id\":\"call_mD5ggLKk3SMKGPFqFdcpKg6q\""), "Should have correct call_id");
|
||||
assert!(
|
||||
output.contains("response.created"),
|
||||
"Should have response.created"
|
||||
);
|
||||
assert!(
|
||||
output.contains("response.in_progress"),
|
||||
"Should have response.in_progress"
|
||||
);
|
||||
assert!(
|
||||
output.contains("response.output_item.added"),
|
||||
"Should have output_item.added"
|
||||
);
|
||||
assert!(
|
||||
output.contains("\"type\":\"function_call\""),
|
||||
"Should be function_call type"
|
||||
);
|
||||
assert!(
|
||||
output.contains("\"name\":\"get_weather\""),
|
||||
"Should have function name"
|
||||
);
|
||||
assert!(
|
||||
output.contains("\"call_id\":\"call_mD5ggLKk3SMKGPFqFdcpKg6q\""),
|
||||
"Should have correct call_id"
|
||||
);
|
||||
|
||||
let delta_count = output.matches("event: response.function_call_arguments.delta").count();
|
||||
let delta_count = output
|
||||
.matches("event: response.function_call_arguments.delta")
|
||||
.count();
|
||||
assert_eq!(delta_count, 4, "Should have 4 delta events");
|
||||
|
||||
assert!(!output.contains("response.function_call_arguments.done"), "Should NOT have arguments.done");
|
||||
assert!(!output.contains("response.output_item.done"), "Should NOT have output_item.done");
|
||||
assert!(!output.contains("response.completed"), "Should NOT have response.completed");
|
||||
assert!(
|
||||
!output.contains("response.function_call_arguments.done"),
|
||||
"Should NOT have arguments.done"
|
||||
);
|
||||
assert!(
|
||||
!output.contains("response.output_item.done"),
|
||||
"Should NOT have output_item.done"
|
||||
);
|
||||
assert!(
|
||||
!output.contains("response.completed"),
|
||||
"Should NOT have response.completed"
|
||||
);
|
||||
|
||||
println!("\nVALIDATION SUMMARY:");
|
||||
println!("{}", "-".repeat(80));
|
||||
println!("✓ Lifecycle events: response.created, response.in_progress");
|
||||
println!("✓ Function call metadata: name='get_weather', call_id='call_mD5ggLKk3SMKGPFqFdcpKg6q'");
|
||||
println!(
|
||||
"✓ Function call metadata: name='get_weather', call_id='call_mD5ggLKk3SMKGPFqFdcpKg6q'"
|
||||
);
|
||||
println!("✓ Incremental deltas: 4 events (1 initial + 3 argument chunks)");
|
||||
println!("✓ NO completion events (partial stream, no [DONE])");
|
||||
println!("✓ Arguments accumulated: '{{\"location\":\"'\n");
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
use crate::providers::streaming_response::ProviderStreamResponse;
|
||||
use crate::providers::streaming_response::ProviderStreamResponseType;
|
||||
use crate::apis::streaming_shapes::chat_completions_streaming_buffer::OpenAIChatCompletionsStreamBuffer;
|
||||
use crate::apis::streaming_shapes::anthropic_streaming_buffer::AnthropicMessagesStreamBuffer;
|
||||
use crate::apis::streaming_shapes::chat_completions_streaming_buffer::OpenAIChatCompletionsStreamBuffer;
|
||||
use crate::apis::streaming_shapes::passthrough_streaming_buffer::PassthroughStreamBuffer;
|
||||
use crate::apis::streaming_shapes::responses_api_streaming_buffer::ResponsesAPIStreamBuffer;
|
||||
use crate::providers::streaming_response::ProviderStreamResponse;
|
||||
use crate::providers::streaming_response::ProviderStreamResponseType;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
|
|
@ -37,7 +37,7 @@ pub trait SseStreamBufferTrait: Send + Sync {
|
|||
///
|
||||
/// # Returns
|
||||
/// Bytes ready for wire transmission (may be empty if no events were accumulated)
|
||||
fn into_bytes(&mut self) -> Vec<u8>;
|
||||
fn to_bytes(&mut self) -> Vec<u8>;
|
||||
}
|
||||
|
||||
/// Unified SSE Stream Buffer enum that provides a zero-cost abstraction
|
||||
|
|
@ -45,7 +45,7 @@ pub enum SseStreamBuffer {
|
|||
Passthrough(PassthroughStreamBuffer),
|
||||
OpenAIChatCompletions(OpenAIChatCompletionsStreamBuffer),
|
||||
AnthropicMessages(AnthropicMessagesStreamBuffer),
|
||||
OpenAIResponses(ResponsesAPIStreamBuffer),
|
||||
OpenAIResponses(Box<ResponsesAPIStreamBuffer>),
|
||||
}
|
||||
|
||||
impl SseStreamBufferTrait for SseStreamBuffer {
|
||||
|
|
@ -58,12 +58,12 @@ impl SseStreamBufferTrait for SseStreamBuffer {
|
|||
}
|
||||
}
|
||||
|
||||
fn into_bytes(&mut self) -> Vec<u8> {
|
||||
fn to_bytes(&mut self) -> Vec<u8> {
|
||||
match self {
|
||||
Self::Passthrough(buffer) => buffer.into_bytes(),
|
||||
Self::OpenAIChatCompletions(buffer) => buffer.into_bytes(),
|
||||
Self::AnthropicMessages(buffer) => buffer.into_bytes(),
|
||||
Self::OpenAIResponses(buffer) => buffer.into_bytes(),
|
||||
Self::Passthrough(buffer) => buffer.to_bytes(),
|
||||
Self::OpenAIChatCompletions(buffer) => buffer.to_bytes(),
|
||||
Self::AnthropicMessages(buffer) => buffer.to_bytes(),
|
||||
Self::OpenAIResponses(buffer) => buffer.to_bytes(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -99,7 +99,7 @@ impl SseEvent {
|
|||
let sse_string: String = response.clone().into();
|
||||
|
||||
SseEvent {
|
||||
data: None, // Data is embedded in sse_transformed_lines
|
||||
data: None, // Data is embedded in sse_transformed_lines
|
||||
event: None, // Event type is embedded in sse_transformed_lines
|
||||
raw_line: sse_string.clone(),
|
||||
sse_transformed_lines: sse_string,
|
||||
|
|
@ -149,10 +149,8 @@ impl FromStr for SseEvent {
|
|||
});
|
||||
}
|
||||
|
||||
if trimmed_line.starts_with("data: ") {
|
||||
let data: String = trimmed_line[6..].to_string(); // Remove "data: " prefix
|
||||
// Allow empty data content after "data: " prefix
|
||||
// This handles cases like "data: " followed by newline
|
||||
if let Some(stripped) = trimmed_line.strip_prefix("data: ") {
|
||||
let data: String = stripped.to_string();
|
||||
if data.trim().is_empty() {
|
||||
return Err(SseParseError {
|
||||
message: "Empty data field after 'data: ' prefix".to_string(),
|
||||
|
|
@ -166,8 +164,8 @@ impl FromStr for SseEvent {
|
|||
sse_transformed_lines: line.to_string(),
|
||||
provider_stream_response: None,
|
||||
})
|
||||
} else if trimmed_line.starts_with("event: ") {
|
||||
let event_type = trimmed_line[7..].to_string();
|
||||
} else if let Some(stripped) = trimmed_line.strip_prefix("event: ") {
|
||||
let event_type = stripped.to_string();
|
||||
if event_type.is_empty() {
|
||||
return Err(SseParseError {
|
||||
message: "Empty event field is not a valid SSE event".to_string(),
|
||||
|
|
@ -183,7 +181,10 @@ impl FromStr for SseEvent {
|
|||
})
|
||||
} else {
|
||||
Err(SseParseError {
|
||||
message: format!("Line does not start with 'data: ' or 'event: ': {}", trimmed_line),
|
||||
message: format!(
|
||||
"Line does not start with 'data: ' or 'event: ': {}",
|
||||
trimmed_line
|
||||
),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -196,16 +197,16 @@ impl fmt::Display for SseEvent {
|
|||
}
|
||||
|
||||
// Into implementation to convert SseEvent to bytes for response buffer
|
||||
impl Into<Vec<u8>> for SseEvent {
|
||||
fn into(self) -> Vec<u8> {
|
||||
impl From<SseEvent> for Vec<u8> {
|
||||
fn from(val: SseEvent) -> Self {
|
||||
// For generated events (like ResponsesAPI), sse_transformed_lines already includes trailing \n\n
|
||||
// For parsed events (like passthrough), we need to add the \n\n separator
|
||||
if self.sse_transformed_lines.ends_with("\n\n") {
|
||||
if val.sse_transformed_lines.ends_with("\n\n") {
|
||||
// Already properly formatted with trailing newlines
|
||||
self.sse_transformed_lines.into_bytes()
|
||||
val.sse_transformed_lines.into_bytes()
|
||||
} else {
|
||||
// Add SSE event separator
|
||||
format!("{}\n\n", self.sse_transformed_lines).into_bytes()
|
||||
format!("{}\n\n", val.sse_transformed_lines).into_bytes()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,6 +10,12 @@ pub struct SseChunkProcessor {
|
|||
incomplete_event_buffer: Vec<u8>,
|
||||
}
|
||||
|
||||
impl Default for SseChunkProcessor {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl SseChunkProcessor {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
|
|
@ -93,8 +99,8 @@ impl SseChunkProcessor {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
||||
use crate::apis::openai::OpenAIApi;
|
||||
use crate::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
||||
|
||||
#[test]
|
||||
fn test_complete_events_process_immediately() {
|
||||
|
|
@ -104,7 +110,9 @@ mod tests {
|
|||
|
||||
let chunk1 = b"data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n";
|
||||
|
||||
let events = processor.process_chunk(chunk1, &client_api, &upstream_api).unwrap();
|
||||
let events = processor
|
||||
.process_chunk(chunk1, &client_api, &upstream_api)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(events.len(), 1);
|
||||
assert!(!processor.has_buffered_data());
|
||||
|
|
@ -119,18 +127,28 @@ mod tests {
|
|||
// First chunk with incomplete JSON
|
||||
let chunk1 = b"data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chu";
|
||||
|
||||
let events1 = processor.process_chunk(chunk1, &client_api, &upstream_api).unwrap();
|
||||
let events1 = processor
|
||||
.process_chunk(chunk1, &client_api, &upstream_api)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(events1.len(), 0, "Incomplete event should not be processed");
|
||||
assert!(processor.has_buffered_data(), "Incomplete data should be buffered");
|
||||
assert!(
|
||||
processor.has_buffered_data(),
|
||||
"Incomplete data should be buffered"
|
||||
);
|
||||
|
||||
// Second chunk completes the JSON
|
||||
let chunk2 = b"nk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n";
|
||||
|
||||
let events2 = processor.process_chunk(chunk2, &client_api, &upstream_api).unwrap();
|
||||
let events2 = processor
|
||||
.process_chunk(chunk2, &client_api, &upstream_api)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(events2.len(), 1, "Complete event should be processed");
|
||||
assert!(!processor.has_buffered_data(), "Buffer should be cleared after completion");
|
||||
assert!(
|
||||
!processor.has_buffered_data(),
|
||||
"Buffer should be cleared after completion"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -142,10 +160,15 @@ mod tests {
|
|||
// Chunk with 2 complete events and 1 incomplete
|
||||
let chunk = b"data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"A\"},\"finish_reason\":null}]}\n\ndata: {\"id\":\"chatcmpl-124\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"B\"},\"finish_reason\":null}]}\n\ndata: {\"id\":\"chatcmpl-125\",\"object\":\"chat.completion.chu";
|
||||
|
||||
let events = processor.process_chunk(chunk, &client_api, &upstream_api).unwrap();
|
||||
let events = processor
|
||||
.process_chunk(chunk, &client_api, &upstream_api)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(events.len(), 2, "Two complete events should be processed");
|
||||
assert!(processor.has_buffered_data(), "Incomplete third event should be buffered");
|
||||
assert!(
|
||||
processor.has_buffered_data(),
|
||||
"Incomplete third event should be buffered"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -171,11 +194,23 @@ data: {"type":"content_block_stop","index":0}
|
|||
Ok(events) => {
|
||||
println!("Successfully processed {} events", events.len());
|
||||
for (i, event) in events.iter().enumerate() {
|
||||
println!("Event {}: event={:?}, has_data={}", i, event.event, event.data.is_some());
|
||||
println!(
|
||||
"Event {}: event={:?}, has_data={}",
|
||||
i,
|
||||
event.event,
|
||||
event.data.is_some()
|
||||
);
|
||||
}
|
||||
// Should successfully process both events (signature_delta + content_block_stop)
|
||||
assert!(events.len() >= 2, "Should process at least 2 complete events (signature_delta + stop), got {}", events.len());
|
||||
assert!(!processor.has_buffered_data(), "Complete events should not be buffered");
|
||||
assert!(
|
||||
events.len() >= 2,
|
||||
"Should process at least 2 complete events (signature_delta + stop), got {}",
|
||||
events.len()
|
||||
);
|
||||
assert!(
|
||||
!processor.has_buffered_data(),
|
||||
"Complete events should not be buffered"
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
panic!("Failed to process signature_delta chunk - this means SignatureDelta is not properly handled: {}", e);
|
||||
|
|
@ -194,12 +229,21 @@ data: {"type":"content_block_stop","index":0}
|
|||
// Second event is valid and should be processed
|
||||
let chunk = b"data: {\"id\":\"chatcmpl-123\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"unsupported_field_causing_validation_error\":true},\"finish_reason\":null}]}\n\ndata: {\"id\":\"chatcmpl-124\",\"object\":\"chat.completion.chunk\",\"created\":1234567890,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\"finish_reason\":null}]}\n\n";
|
||||
|
||||
let events = processor.process_chunk(chunk, &client_api, &upstream_api).unwrap();
|
||||
let events = processor
|
||||
.process_chunk(chunk, &client_api, &upstream_api)
|
||||
.unwrap();
|
||||
|
||||
// Should skip the invalid event and process the valid one
|
||||
// (If we were buffering all errors, we'd get 0 events and have buffered data)
|
||||
assert!(events.len() >= 1, "Should process at least the valid event, got {} events", events.len());
|
||||
assert!(!processor.has_buffered_data(), "Invalid (non-incomplete) events should not be buffered");
|
||||
assert!(
|
||||
!events.is_empty(),
|
||||
"Should process at least the valid event, got {} events",
|
||||
events.len()
|
||||
);
|
||||
assert!(
|
||||
!processor.has_buffered_data(),
|
||||
"Invalid (non-incomplete) events should not be buffered"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -227,14 +271,27 @@ data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text
|
|||
|
||||
match result {
|
||||
Ok(events) => {
|
||||
println!("Processed {} events (unsupported event should be skipped)", events.len());
|
||||
println!(
|
||||
"Processed {} events (unsupported event should be skipped)",
|
||||
events.len()
|
||||
);
|
||||
// Should process the 2 valid text_delta events and skip the unsupported one
|
||||
// We expect at least 2 events (the valid ones), unsupported should be skipped
|
||||
assert!(events.len() >= 2, "Should process at least 2 valid events, got {}", events.len());
|
||||
assert!(!processor.has_buffered_data(), "Unsupported events should be skipped, not buffered");
|
||||
assert!(
|
||||
events.len() >= 2,
|
||||
"Should process at least 2 valid events, got {}",
|
||||
events.len()
|
||||
);
|
||||
assert!(
|
||||
!processor.has_buffered_data(),
|
||||
"Unsupported events should be skipped, not buffered"
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
panic!("Should not fail on unsupported delta type, should skip it: {}", e);
|
||||
panic!(
|
||||
"Should not fail on unsupported delta type, should skip it: {}",
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -135,7 +135,10 @@ impl SupportedAPIsFromClient {
|
|||
ProviderId::AzureOpenAI => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
let suffix = endpoint_suffix.trim_start_matches('/');
|
||||
build_endpoint("/openai/deployments", &format!("/{}/{}?api-version=2025-01-01-preview", model_id, suffix))
|
||||
build_endpoint(
|
||||
"/openai/deployments",
|
||||
&format!("/{}/{}?api-version=2025-01-01-preview", model_id, suffix),
|
||||
)
|
||||
} else {
|
||||
build_endpoint("/v1", endpoint_suffix)
|
||||
}
|
||||
|
|
@ -163,19 +166,21 @@ impl SupportedAPIsFromClient {
|
|||
};
|
||||
|
||||
match self {
|
||||
SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages) => match provider_id {
|
||||
ProviderId::Anthropic => build_endpoint("/v1", "/messages"),
|
||||
ProviderId::AmazonBedrock => {
|
||||
if request_path.starts_with("/v1/") && !is_streaming {
|
||||
build_endpoint("", &format!("/model/{}/converse", model_id))
|
||||
} else if request_path.starts_with("/v1/") && is_streaming {
|
||||
build_endpoint("", &format!("/model/{}/converse-stream", model_id))
|
||||
} else {
|
||||
build_endpoint("/v1", "/chat/completions")
|
||||
SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages) => {
|
||||
match provider_id {
|
||||
ProviderId::Anthropic => build_endpoint("/v1", "/messages"),
|
||||
ProviderId::AmazonBedrock => {
|
||||
if request_path.starts_with("/v1/") && !is_streaming {
|
||||
build_endpoint("", &format!("/model/{}/converse", model_id))
|
||||
} else if request_path.starts_with("/v1/") && is_streaming {
|
||||
build_endpoint("", &format!("/model/{}/converse-stream", model_id))
|
||||
} else {
|
||||
build_endpoint("/v1", "/chat/completions")
|
||||
}
|
||||
}
|
||||
_ => build_endpoint("/v1", "/chat/completions"),
|
||||
}
|
||||
_ => build_endpoint("/v1", "/chat/completions"),
|
||||
},
|
||||
}
|
||||
SupportedAPIsFromClient::OpenAIResponsesAPI(_) => {
|
||||
// For Responses API, check if provider supports it, otherwise translate to chat/completions
|
||||
match provider_id {
|
||||
|
|
@ -193,7 +198,6 @@ impl SupportedAPIsFromClient {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
impl SupportedUpstreamAPIs {
|
||||
/// Create a SupportedUpstreamApi from an endpoint path
|
||||
pub fn from_endpoint(endpoint: &str) -> Option<Self> {
|
||||
|
|
@ -216,17 +220,17 @@ impl SupportedUpstreamAPIs {
|
|||
return Some(SupportedUpstreamAPIs::AmazonBedrockConverse(bedrock_api))
|
||||
}
|
||||
AmazonBedrockApi::ConverseStream => {
|
||||
return Some(SupportedUpstreamAPIs::AmazonBedrockConverseStream(bedrock_api))
|
||||
return Some(SupportedUpstreamAPIs::AmazonBedrockConverseStream(
|
||||
bedrock_api,
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
/// Get all supported endpoint paths
|
||||
pub fn supported_endpoints() -> Vec<&'static str> {
|
||||
let mut endpoints = Vec::new();
|
||||
|
|
@ -269,9 +273,9 @@ mod tests {
|
|||
assert!(SupportedAPIsFromClient::from_endpoint("/v1/messages").is_some());
|
||||
|
||||
// Unsupported endpoints
|
||||
assert!(!SupportedAPIsFromClient::from_endpoint("/v1/unknown").is_some());
|
||||
assert!(!SupportedAPIsFromClient::from_endpoint("/v2/chat").is_some());
|
||||
assert!(!SupportedAPIsFromClient::from_endpoint("").is_some());
|
||||
assert!(SupportedAPIsFromClient::from_endpoint("/v1/unknown").is_none());
|
||||
assert!(SupportedAPIsFromClient::from_endpoint("/v2/chat").is_none());
|
||||
assert!(SupportedAPIsFromClient::from_endpoint("").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -12,11 +12,9 @@ pub use aws_smithy_eventstream::frame::DecodedFrame;
|
|||
pub use providers::id::ProviderId;
|
||||
pub use providers::request::{ProviderRequest, ProviderRequestError, ProviderRequestType};
|
||||
pub use providers::response::{
|
||||
ProviderResponse, ProviderResponseType, TokenUsage, ProviderResponseError
|
||||
};
|
||||
pub use providers::streaming_response::{
|
||||
ProviderStreamResponse, ProviderStreamResponseType
|
||||
ProviderResponse, ProviderResponseError, ProviderResponseType, TokenUsage,
|
||||
};
|
||||
pub use providers::streaming_response::{ProviderStreamResponse, ProviderStreamResponseType};
|
||||
|
||||
//TODO: Refactor such that commons doesn't depend on Hermes. For now this will clean up strings
|
||||
pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
|
||||
|
|
@ -87,11 +85,17 @@ mod tests {
|
|||
let done_event = streaming_iter.next();
|
||||
assert!(done_event.is_some(), "Should get [DONE] event");
|
||||
let done_event = done_event.unwrap();
|
||||
assert!(done_event.is_done(), "[DONE] event should be marked as done");
|
||||
assert!(
|
||||
done_event.is_done(),
|
||||
"[DONE] event should be marked as done"
|
||||
);
|
||||
|
||||
// After [DONE], iterator should return None
|
||||
let final_event = streaming_iter.next();
|
||||
assert!(final_event.is_none(), "Iterator should return None after [DONE]");
|
||||
assert!(
|
||||
final_event.is_none(),
|
||||
"Iterator should return None after [DONE]"
|
||||
);
|
||||
}
|
||||
|
||||
/// Test AWS Event Stream decoding for Bedrock ConverseStream responses.
|
||||
|
|
@ -130,7 +134,7 @@ mod tests {
|
|||
let mut content_chunks = Vec::new();
|
||||
|
||||
// Simulate chunked network arrivals - process as data comes in
|
||||
let chunk_sizes = vec![50, 100, 75, 200, 150, 300, 500, 1000];
|
||||
let chunk_sizes = [50, 100, 75, 200, 150, 300, 500, 1000];
|
||||
let mut offset = 0;
|
||||
let mut chunk_num = 0;
|
||||
|
||||
|
|
|
|||
|
|
@ -59,10 +59,9 @@ impl ProviderId {
|
|||
(ProviderId::Anthropic, SupportedAPIsFromClient::AnthropicMessagesAPI(_)) => {
|
||||
SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages)
|
||||
}
|
||||
(
|
||||
ProviderId::Anthropic,
|
||||
SupportedAPIsFromClient::OpenAIChatCompletions(_),
|
||||
) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
(ProviderId::Anthropic, SupportedAPIsFromClient::OpenAIChatCompletions(_)) => {
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions)
|
||||
}
|
||||
|
||||
// Anthropic doesn't support Responses API, fall back to chat completions
|
||||
(ProviderId::Anthropic, SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => {
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ use serde_json::Value;
|
|||
use std::collections::HashMap;
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum ProviderRequestType {
|
||||
ChatCompletionsRequest(ChatCompletionsRequest),
|
||||
|
|
@ -197,7 +198,9 @@ impl ProviderRequest for ProviderRequestType {
|
|||
impl TryFrom<(&[u8], &SupportedAPIsFromClient)> for ProviderRequestType {
|
||||
type Error = std::io::Error;
|
||||
|
||||
fn try_from((bytes, client_api): (&[u8], &SupportedAPIsFromClient)) -> Result<Self, Self::Error> {
|
||||
fn try_from(
|
||||
(bytes, client_api): (&[u8], &SupportedAPIsFromClient),
|
||||
) -> Result<Self, Self::Error> {
|
||||
// Use SupportedApi to determine the appropriate request type
|
||||
match client_api {
|
||||
SupportedAPIsFromClient::OpenAIChatCompletions(_) => {
|
||||
|
|
@ -882,7 +885,7 @@ mod tests {
|
|||
ProviderRequestType::BedrockConverse(bedrock_req) => {
|
||||
assert_eq!(bedrock_req.model_id, "gpt-4o");
|
||||
// Bedrock receives the converted request through ChatCompletions
|
||||
assert!(!bedrock_req.messages.is_none());
|
||||
assert!(bedrock_req.messages.is_some());
|
||||
}
|
||||
_ => panic!("Expected BedrockConverse variant"),
|
||||
}
|
||||
|
|
@ -913,7 +916,9 @@ mod tests {
|
|||
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert!(err.message.contains("ResponsesAPI can only be used as a client API"));
|
||||
assert!(err
|
||||
.message
|
||||
.contains("ResponsesAPI can only be used as a client API"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -953,7 +958,9 @@ mod tests {
|
|||
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert!(err.message.contains("ResponsesAPI can only be used as a client API"));
|
||||
assert!(err
|
||||
.message
|
||||
.contains("ResponsesAPI can only be used as a client API"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -1023,9 +1030,7 @@ mod tests {
|
|||
role: MessagesRole::User,
|
||||
content: MessagesMessageContent::Single("Hello!".to_string()),
|
||||
}],
|
||||
system: Some(MessagesSystemPrompt::Single(
|
||||
"You are helpful".to_string(),
|
||||
)),
|
||||
system: Some(MessagesSystemPrompt::Single("You are helpful".to_string())),
|
||||
max_tokens: 100,
|
||||
container: None,
|
||||
mcp_servers: None,
|
||||
|
|
@ -1046,14 +1051,8 @@ mod tests {
|
|||
|
||||
// Should have system message + user message
|
||||
assert_eq!(messages.len(), 2);
|
||||
assert_eq!(
|
||||
messages[0].role,
|
||||
crate::apis::openai::Role::System
|
||||
);
|
||||
assert_eq!(
|
||||
messages[1].role,
|
||||
crate::apis::openai::Role::User
|
||||
);
|
||||
assert_eq!(messages[0].role, crate::apis::openai::Role::System);
|
||||
assert_eq!(messages[1].role, crate::apis::openai::Role::User);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -1094,13 +1093,7 @@ mod tests {
|
|||
|
||||
// Should have system message (instructions) + user message (input)
|
||||
assert_eq!(messages.len(), 2);
|
||||
assert_eq!(
|
||||
messages[0].role,
|
||||
crate::apis::openai::Role::System
|
||||
);
|
||||
assert_eq!(
|
||||
messages[1].role,
|
||||
crate::apis::openai::Role::User
|
||||
);
|
||||
assert_eq!(messages[0].role, crate::apis::openai::Role::System);
|
||||
assert_eq!(messages[1].role, crate::apis::openai::Role::User);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,3 @@
|
|||
use serde::Serialize;
|
||||
use std::convert::TryFrom;
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
use crate::apis::amazon_bedrock::ConverseResponse;
|
||||
use crate::apis::anthropic::MessagesResponse;
|
||||
use crate::apis::openai::ChatCompletionsResponse;
|
||||
|
|
@ -9,14 +5,17 @@ use crate::apis::openai_responses::ResponsesAPIResponse;
|
|||
use crate::clients::endpoints::SupportedAPIsFromClient;
|
||||
use crate::clients::endpoints::SupportedUpstreamAPIs;
|
||||
use crate::providers::id::ProviderId;
|
||||
|
||||
use serde::Serialize;
|
||||
use std::convert::TryFrom;
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
|
||||
#[derive(Serialize, Debug, Clone)]
|
||||
#[serde(untagged)]
|
||||
pub enum ProviderResponseType {
|
||||
ChatCompletionsResponse(ChatCompletionsResponse),
|
||||
MessagesResponse(MessagesResponse),
|
||||
ResponsesAPIResponse(ResponsesAPIResponse),
|
||||
ResponsesAPIResponse(Box<ResponsesAPIResponse>),
|
||||
}
|
||||
|
||||
/// Trait for token usage information
|
||||
|
|
@ -42,7 +41,9 @@ impl ProviderResponse for ProviderResponseType {
|
|||
match self {
|
||||
ProviderResponseType::ChatCompletionsResponse(resp) => resp.usage(),
|
||||
ProviderResponseType::MessagesResponse(resp) => resp.usage(),
|
||||
ProviderResponseType::ResponsesAPIResponse(resp) => resp.usage.as_ref().map(|u| u as &dyn TokenUsage),
|
||||
ProviderResponseType::ResponsesAPIResponse(resp) => {
|
||||
resp.usage.as_ref().map(|u| u as &dyn TokenUsage)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -50,11 +51,13 @@ impl ProviderResponse for ProviderResponseType {
|
|||
match self {
|
||||
ProviderResponseType::ChatCompletionsResponse(resp) => resp.extract_usage_counts(),
|
||||
ProviderResponseType::MessagesResponse(resp) => resp.extract_usage_counts(),
|
||||
ProviderResponseType::ResponsesAPIResponse(resp) => {
|
||||
resp.usage.as_ref().map(|u| {
|
||||
(u.input_tokens as usize, u.output_tokens as usize, u.total_tokens as usize)
|
||||
})
|
||||
}
|
||||
ProviderResponseType::ResponsesAPIResponse(resp) => resp.usage.as_ref().map(|u| {
|
||||
(
|
||||
u.input_tokens as usize,
|
||||
u.output_tokens as usize,
|
||||
u.total_tokens as usize,
|
||||
)
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -156,40 +159,44 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClient, &ProviderId)> for ProviderRespons
|
|||
) => {
|
||||
let resp: ResponsesAPIResponse = ResponsesAPIResponse::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderResponseType::ResponsesAPIResponse(resp))
|
||||
Ok(ProviderResponseType::ResponsesAPIResponse(Box::new(resp)))
|
||||
}
|
||||
(
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
|
||||
SupportedAPIsFromClient::OpenAIResponsesAPI(_),
|
||||
) => {
|
||||
let chat_completions_response: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
let chat_completions_response: ChatCompletionsResponse =
|
||||
ChatCompletionsResponse::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
|
||||
// Transform to ResponsesAPI format using the transformer
|
||||
let responses_resp: ResponsesAPIResponse = chat_completions_response.try_into().map_err(|e| {
|
||||
std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
format!("Transformation error: {}", e),
|
||||
)
|
||||
})?;
|
||||
Ok(ProviderResponseType::ResponsesAPIResponse(responses_resp))
|
||||
let responses_resp: ResponsesAPIResponse =
|
||||
chat_completions_response.try_into().map_err(|e| {
|
||||
std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
format!("Transformation error: {}", e),
|
||||
)
|
||||
})?;
|
||||
Ok(ProviderResponseType::ResponsesAPIResponse(Box::new(
|
||||
responses_resp,
|
||||
)))
|
||||
}
|
||||
(
|
||||
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
|
||||
SupportedAPIsFromClient::OpenAIResponsesAPI(_),
|
||||
) => {
|
||||
|
||||
//Chain transform: Anthropic Messages -> OpenAI ChatCompletions -> ResponsesAPI
|
||||
let anthropic_resp: MessagesResponse = serde_json::from_slice(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
|
||||
// Transform to ChatCompletions format using the transformer
|
||||
let chat_resp: ChatCompletionsResponse = anthropic_resp.try_into().map_err(|e| {
|
||||
std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
format!("Transformation error: {}", e),
|
||||
)
|
||||
})?;
|
||||
let chat_resp: ChatCompletionsResponse =
|
||||
anthropic_resp.try_into().map_err(|e| {
|
||||
std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
format!("Transformation error: {}", e),
|
||||
)
|
||||
})?;
|
||||
|
||||
let response_api: ResponsesAPIResponse = chat_resp.try_into().map_err(|e| {
|
||||
std::io::Error::new(
|
||||
|
|
@ -197,7 +204,9 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClient, &ProviderId)> for ProviderRespons
|
|||
format!("Transformation error: {}", e),
|
||||
)
|
||||
})?;
|
||||
Ok(ProviderResponseType::ResponsesAPIResponse(response_api))
|
||||
Ok(ProviderResponseType::ResponsesAPIResponse(Box::new(
|
||||
response_api,
|
||||
)))
|
||||
}
|
||||
(
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverse(_),
|
||||
|
|
@ -219,10 +228,15 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClient, &ProviderId)> for ProviderRespons
|
|||
let response_api: ResponsesAPIResponse = chat_resp.try_into().map_err(|e| {
|
||||
std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
format!("ChatCompletions to ResponsesAPI transformation error: {}", e),
|
||||
format!(
|
||||
"ChatCompletions to ResponsesAPI transformation error: {}",
|
||||
e
|
||||
),
|
||||
)
|
||||
})?;
|
||||
Ok(ProviderResponseType::ResponsesAPIResponse(response_api))
|
||||
Ok(ProviderResponseType::ResponsesAPIResponse(Box::new(
|
||||
response_api,
|
||||
)))
|
||||
}
|
||||
_ => Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
|
|
@ -255,8 +269,8 @@ impl Error for ProviderResponseError {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::apis::openai::OpenAIApi;
|
||||
use crate::apis::anthropic::AnthropicApi;
|
||||
use crate::apis::openai::OpenAIApi;
|
||||
use crate::clients::endpoints::SupportedAPIsFromClient;
|
||||
use crate::providers::id::ProviderId;
|
||||
use serde_json::json;
|
||||
|
|
|
|||
|
|
@ -1,18 +1,17 @@
|
|||
use serde::Serialize;
|
||||
use std::convert::TryFrom;
|
||||
|
||||
use crate::apis::amazon_bedrock::ConverseStreamEvent;
|
||||
use crate::apis::anthropic::MessagesStreamEvent;
|
||||
use crate::apis::openai::ChatCompletionsStreamResponse;
|
||||
use crate::apis::openai_responses::ResponsesAPIStreamEvent;
|
||||
use crate::apis::streaming_shapes::sse::SseEvent;
|
||||
use crate::apis::amazon_bedrock::ConverseStreamEvent;
|
||||
use crate::apis::anthropic::MessagesStreamEvent;
|
||||
use crate::apis::streaming_shapes::sse::SseStreamBuffer;
|
||||
use crate::apis::streaming_shapes::{
|
||||
anthropic_streaming_buffer::AnthropicMessagesStreamBuffer,
|
||||
chat_completions_streaming_buffer::OpenAIChatCompletionsStreamBuffer,
|
||||
passthrough_streaming_buffer::PassthroughStreamBuffer,
|
||||
responses_api_streaming_buffer::ResponsesAPIStreamBuffer,
|
||||
};
|
||||
anthropic_streaming_buffer::AnthropicMessagesStreamBuffer,
|
||||
chat_completions_streaming_buffer::OpenAIChatCompletionsStreamBuffer,
|
||||
passthrough_streaming_buffer::PassthroughStreamBuffer,
|
||||
};
|
||||
|
||||
use crate::clients::endpoints::SupportedAPIsFromClient;
|
||||
use crate::clients::endpoints::SupportedUpstreamAPIs;
|
||||
|
|
@ -28,9 +27,18 @@ pub fn needs_buffering(
|
|||
) -> bool {
|
||||
match (client_api, upstream_api) {
|
||||
// Same APIs - no buffering needed
|
||||
(SupportedAPIsFromClient::OpenAIChatCompletions(_), SupportedUpstreamAPIs::OpenAIChatCompletions(_)) => false,
|
||||
(SupportedAPIsFromClient::AnthropicMessagesAPI(_), SupportedUpstreamAPIs::AnthropicMessagesAPI(_)) => false,
|
||||
(SupportedAPIsFromClient::OpenAIResponsesAPI(_), SupportedUpstreamAPIs::OpenAIResponsesAPI(_)) => false,
|
||||
(
|
||||
SupportedAPIsFromClient::OpenAIChatCompletions(_),
|
||||
SupportedUpstreamAPIs::OpenAIChatCompletions(_),
|
||||
) => false,
|
||||
(
|
||||
SupportedAPIsFromClient::AnthropicMessagesAPI(_),
|
||||
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
|
||||
) => false,
|
||||
(
|
||||
SupportedAPIsFromClient::OpenAIResponsesAPI(_),
|
||||
SupportedUpstreamAPIs::OpenAIResponsesAPI(_),
|
||||
) => false,
|
||||
|
||||
// Different APIs - buffering needed
|
||||
_ => true,
|
||||
|
|
@ -53,15 +61,12 @@ pub fn needs_buffering(
|
|||
/// // Flush to wire
|
||||
/// let bytes = buffer.into_bytes();
|
||||
/// ```
|
||||
impl TryFrom<(&SupportedAPIsFromClient, &SupportedUpstreamAPIs)>
|
||||
for SseStreamBuffer
|
||||
{
|
||||
impl TryFrom<(&SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for SseStreamBuffer {
|
||||
type Error = Box<dyn std::error::Error + Send + Sync>;
|
||||
|
||||
fn try_from(
|
||||
(client_api, upstream_api): (&SupportedAPIsFromClient, &SupportedUpstreamAPIs),
|
||||
) -> Result<Self, Self::Error> {
|
||||
|
||||
// If APIs match, use passthrough - no buffering/transformation needed
|
||||
if !needs_buffering(client_api, upstream_api) {
|
||||
return Ok(SseStreamBuffer::Passthrough(PassthroughStreamBuffer::new()));
|
||||
|
|
@ -69,14 +74,14 @@ impl TryFrom<(&SupportedAPIsFromClient, &SupportedUpstreamAPIs)>
|
|||
|
||||
// APIs differ - use appropriate buffer for client API
|
||||
match client_api {
|
||||
SupportedAPIsFromClient::OpenAIChatCompletions(_) => {
|
||||
Ok(SseStreamBuffer::OpenAIChatCompletions(OpenAIChatCompletionsStreamBuffer::new()))
|
||||
}
|
||||
SupportedAPIsFromClient::AnthropicMessagesAPI(_) => {
|
||||
Ok(SseStreamBuffer::AnthropicMessages(AnthropicMessagesStreamBuffer::new()))
|
||||
}
|
||||
SupportedAPIsFromClient::OpenAIChatCompletions(_) => Ok(
|
||||
SseStreamBuffer::OpenAIChatCompletions(OpenAIChatCompletionsStreamBuffer::new()),
|
||||
),
|
||||
SupportedAPIsFromClient::AnthropicMessagesAPI(_) => Ok(
|
||||
SseStreamBuffer::AnthropicMessages(AnthropicMessagesStreamBuffer::new()),
|
||||
),
|
||||
SupportedAPIsFromClient::OpenAIResponsesAPI(_) => {
|
||||
Ok(SseStreamBuffer::OpenAIResponses(ResponsesAPIStreamBuffer::new()))
|
||||
Ok(SseStreamBuffer::OpenAIResponses(Box::default()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -88,11 +93,12 @@ impl TryFrom<(&SupportedAPIsFromClient, &SupportedUpstreamAPIs)>
|
|||
|
||||
#[derive(Serialize, Debug, Clone)]
|
||||
#[serde(untagged)]
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
pub enum ProviderStreamResponseType {
|
||||
ChatCompletionsStreamResponse(ChatCompletionsStreamResponse),
|
||||
MessagesStreamEvent(MessagesStreamEvent),
|
||||
ConverseStreamEvent(ConverseStreamEvent),
|
||||
ResponseAPIStreamEvent(ResponsesAPIStreamEvent)
|
||||
ResponseAPIStreamEvent(Box<ResponsesAPIStreamEvent>),
|
||||
}
|
||||
|
||||
pub trait ProviderStreamResponse: Send + Sync {
|
||||
|
|
@ -145,12 +151,11 @@ impl ProviderStreamResponse for ProviderStreamResponseType {
|
|||
ProviderStreamResponseType::ResponseAPIStreamEvent(resp) => resp.event_type(),
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
impl Into<String> for ProviderStreamResponseType {
|
||||
fn into(self) -> String {
|
||||
match self {
|
||||
impl From<ProviderStreamResponseType> for String {
|
||||
fn from(val: ProviderStreamResponseType) -> String {
|
||||
match val {
|
||||
ProviderStreamResponseType::MessagesStreamEvent(event) => {
|
||||
// Use the Into<String> implementation for proper SSE formatting with event lines
|
||||
event.into()
|
||||
|
|
@ -161,27 +166,36 @@ impl Into<String> for ProviderStreamResponseType {
|
|||
}
|
||||
ProviderStreamResponseType::ResponseAPIStreamEvent(event) => {
|
||||
// Use the Into<String> implementation for proper SSE formatting with event lines
|
||||
event.into()
|
||||
// Clone to work around Box<T> ownership
|
||||
let cloned = (*event).clone();
|
||||
cloned.into()
|
||||
}
|
||||
ProviderStreamResponseType::ChatCompletionsStreamResponse(_) => {
|
||||
// For OpenAI, use simple data line format
|
||||
let json = serde_json::to_string(&self).unwrap_or_default();
|
||||
let json = serde_json::to_string(&val).unwrap_or_default();
|
||||
format!("data: {}\n\n", json)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Stream response transformation logic for client API compatibility
|
||||
impl TryFrom<(&[u8], &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for ProviderStreamResponseType {
|
||||
impl TryFrom<(&[u8], &SupportedAPIsFromClient, &SupportedUpstreamAPIs)>
|
||||
for ProviderStreamResponseType
|
||||
{
|
||||
type Error = Box<dyn std::error::Error + Send + Sync>;
|
||||
|
||||
fn try_from(
|
||||
(bytes, client_api, upstream_api): (&[u8], &SupportedAPIsFromClient, &SupportedUpstreamAPIs),
|
||||
(bytes, client_api, upstream_api): (
|
||||
&[u8],
|
||||
&SupportedAPIsFromClient,
|
||||
&SupportedUpstreamAPIs,
|
||||
),
|
||||
) -> Result<Self, Self::Error> {
|
||||
// Special case: Handle [DONE] marker for OpenAI -> Anthropic conversion
|
||||
if bytes == b"[DONE]" && matches!(client_api, SupportedAPIsFromClient::AnthropicMessagesAPI(_)) {
|
||||
if bytes == b"[DONE]"
|
||||
&& matches!(client_api, SupportedAPIsFromClient::AnthropicMessagesAPI(_))
|
||||
{
|
||||
return Ok(ProviderStreamResponseType::MessagesStreamEvent(
|
||||
crate::apis::anthropic::MessagesStreamEvent::MessageStop,
|
||||
));
|
||||
|
|
@ -214,9 +228,9 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for Prov
|
|||
) => {
|
||||
let openai_resp: crate::apis::openai::ChatCompletionsStreamResponse =
|
||||
serde_json::from_slice(bytes)?;
|
||||
let responses_resp = openai_resp.try_into()?;
|
||||
let responses_resp: ResponsesAPIStreamEvent = openai_resp.try_into()?;
|
||||
Ok(ProviderStreamResponseType::ResponseAPIStreamEvent(
|
||||
responses_resp,
|
||||
Box::new(responses_resp),
|
||||
))
|
||||
}
|
||||
|
||||
|
|
@ -267,10 +281,11 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for Prov
|
|||
// Chain: Bedrock -> ChatCompletions -> ResponsesAPI
|
||||
let bedrock_resp: crate::apis::amazon_bedrock::ConverseStreamEvent =
|
||||
serde_json::from_slice(bytes)?;
|
||||
let chat_resp: crate::apis::openai::ChatCompletionsStreamResponse = bedrock_resp.try_into()?;
|
||||
let responses_resp = chat_resp.try_into()?;
|
||||
let chat_resp: crate::apis::openai::ChatCompletionsStreamResponse =
|
||||
bedrock_resp.try_into()?;
|
||||
let responses_resp: ResponsesAPIStreamEvent = chat_resp.try_into()?;
|
||||
Ok(ProviderStreamResponseType::ResponseAPIStreamEvent(
|
||||
responses_resp,
|
||||
Box::new(responses_resp),
|
||||
))
|
||||
}
|
||||
_ => Err(std::io::Error::new(
|
||||
|
|
@ -287,7 +302,11 @@ impl TryFrom<(SseEvent, &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for S
|
|||
type Error = Box<dyn std::error::Error + Send + Sync>;
|
||||
|
||||
fn try_from(
|
||||
(sse_event, client_api, upstream_api): (SseEvent, &SupportedAPIsFromClient, &SupportedUpstreamAPIs),
|
||||
(sse_event, client_api, upstream_api): (
|
||||
SseEvent,
|
||||
&SupportedAPIsFromClient,
|
||||
&SupportedUpstreamAPIs,
|
||||
),
|
||||
) -> Result<Self, Self::Error> {
|
||||
// Create a new transformed event based on the original
|
||||
let mut transformed_event = sse_event;
|
||||
|
|
@ -296,7 +315,11 @@ impl TryFrom<(SseEvent, &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for S
|
|||
if transformed_event.is_done() {
|
||||
// For OpenAI client APIs (ChatCompletions and ResponsesAPI), keep [DONE] as-is
|
||||
// For Anthropic client API, it will be transformed via ProviderStreamResponseType
|
||||
if matches!(client_api, SupportedAPIsFromClient::OpenAIChatCompletions(_) | SupportedAPIsFromClient::OpenAIResponsesAPI(_)) {
|
||||
if matches!(
|
||||
client_api,
|
||||
SupportedAPIsFromClient::OpenAIChatCompletions(_)
|
||||
| SupportedAPIsFromClient::OpenAIResponsesAPI(_)
|
||||
) {
|
||||
// Keep the [DONE] marker as-is for OpenAI clients
|
||||
transformed_event.sse_transformed_lines = "data: [DONE]".to_string();
|
||||
return Ok(transformed_event);
|
||||
|
|
@ -328,7 +351,7 @@ impl TryFrom<(SseEvent, &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for S
|
|||
// OpenAI clients don't expect separate event: lines
|
||||
// Suppress upstream Anthropic event-only lines
|
||||
if transformed_event.is_event_only() && transformed_event.event.is_some() {
|
||||
transformed_event.sse_transformed_lines = format!("\n");
|
||||
transformed_event.sse_transformed_lines = "\n".to_string();
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
|
|
@ -345,7 +368,8 @@ impl TryFrom<(SseEvent, &SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for S
|
|||
(
|
||||
SupportedAPIsFromClient::AnthropicMessagesAPI(_),
|
||||
SupportedUpstreamAPIs::AnthropicMessagesAPI(_),
|
||||
) | (
|
||||
)
|
||||
| (
|
||||
SupportedAPIsFromClient::OpenAIResponsesAPI(_),
|
||||
SupportedUpstreamAPIs::OpenAIResponsesAPI(_),
|
||||
) => {
|
||||
|
|
@ -415,7 +439,7 @@ impl
|
|||
openai_event,
|
||||
))
|
||||
}
|
||||
(
|
||||
(
|
||||
SupportedUpstreamAPIs::AmazonBedrockConverseStream(_),
|
||||
SupportedAPIsFromClient::OpenAIResponsesAPI(_),
|
||||
) => {
|
||||
|
|
@ -428,7 +452,7 @@ impl
|
|||
openai_chat_completions_event.try_into()?;
|
||||
|
||||
Ok(ProviderStreamResponseType::ResponseAPIStreamEvent(
|
||||
openai_responses_api_event,
|
||||
Box::new(openai_responses_api_event),
|
||||
))
|
||||
}
|
||||
_ => Err("Unsupported API combination for event-stream decoding".into()),
|
||||
|
|
@ -445,11 +469,11 @@ impl
|
|||
mod tests {
|
||||
use super::*;
|
||||
use crate::apis::streaming_shapes::amazon_bedrock_binary_frame::BedrockBinaryFrameDecoder;
|
||||
use crate::clients::endpoints::SupportedAPIsFromClient;
|
||||
use crate::apis::streaming_shapes::sse::SseStreamIter;
|
||||
use crate::clients::endpoints::SupportedAPIsFromClient;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
#[test]
|
||||
fn test_sse_event_parsing() {
|
||||
// Test valid SSE data line
|
||||
let line = "data: {\"id\":\"test\",\"object\":\"chat.completion.chunk\"}\n\n";
|
||||
|
|
@ -792,7 +816,7 @@ mod tests {
|
|||
// Simulate chunked network arrivals with realistic chunk sizes
|
||||
// Using varying chunk sizes to test partial frame handling
|
||||
let mut buffer = BytesMut::new();
|
||||
let chunk_size_pattern = vec![500, 1000, 750, 1200, 800, 1500];
|
||||
let chunk_size_pattern = [500, 1000, 750, 1200, 800, 1500];
|
||||
let mut offset = 0;
|
||||
let mut total_frames = 0;
|
||||
let mut chunk_num = 0;
|
||||
|
|
@ -837,7 +861,7 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[test]
|
||||
fn test_bedrock_decoded_frame_to_provider_response() {
|
||||
test_bedrock_conversion(false);
|
||||
}
|
||||
|
|
@ -879,8 +903,9 @@ mod tests {
|
|||
|
||||
let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer);
|
||||
|
||||
let client_api =
|
||||
SupportedAPIsFromClient::AnthropicMessagesAPI(crate::apis::anthropic::AnthropicApi::Messages);
|
||||
let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(
|
||||
crate::apis::anthropic::AnthropicApi::Messages,
|
||||
);
|
||||
let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream(
|
||||
crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream,
|
||||
);
|
||||
|
|
@ -966,8 +991,9 @@ mod tests {
|
|||
|
||||
let mut decoder = BedrockBinaryFrameDecoder::new(&mut buffer);
|
||||
|
||||
let client_api =
|
||||
SupportedAPIsFromClient::AnthropicMessagesAPI(crate::apis::anthropic::AnthropicApi::Messages);
|
||||
let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(
|
||||
crate::apis::anthropic::AnthropicApi::Messages,
|
||||
);
|
||||
let upstream_api = SupportedUpstreamAPIs::AmazonBedrockConverseStream(
|
||||
crate::apis::amazon_bedrock::AmazonBedrockApi::ConverseStream,
|
||||
);
|
||||
|
|
@ -1051,7 +1077,6 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_sse_event_transformation_openai_to_anthropic_message_delta() {
|
||||
use crate::apis::anthropic::AnthropicApi;
|
||||
|
|
@ -1079,8 +1104,8 @@ mod tests {
|
|||
let sse_event = SseEvent {
|
||||
data: Some(openai_stream_chunk.to_string()),
|
||||
event: None,
|
||||
raw_line: format!("data: {}", openai_stream_chunk.to_string()),
|
||||
sse_transformed_lines: format!("data: {}", openai_stream_chunk.to_string()),
|
||||
raw_line: format!("data: {}", openai_stream_chunk),
|
||||
sse_transformed_lines: format!("data: {}", openai_stream_chunk),
|
||||
provider_stream_response: None,
|
||||
};
|
||||
|
||||
|
|
@ -1101,7 +1126,8 @@ mod tests {
|
|||
// Verify the event was transformed to Anthropic format
|
||||
// This should contain message_delta with stop_reason and usage
|
||||
assert!(
|
||||
buffer.contains("event: message_delta") || buffer.contains("\"type\":\"message_delta\""),
|
||||
buffer.contains("event: message_delta")
|
||||
|| buffer.contains("\"type\":\"message_delta\""),
|
||||
"Should contain message_delta in transformed event"
|
||||
);
|
||||
|
||||
|
|
@ -1134,8 +1160,8 @@ mod tests {
|
|||
let sse_event = SseEvent {
|
||||
data: Some(openai_stream_chunk.to_string()),
|
||||
event: None,
|
||||
raw_line: format!("data: {}", openai_stream_chunk.to_string()),
|
||||
sse_transformed_lines: format!("data: {}", openai_stream_chunk.to_string()),
|
||||
raw_line: format!("data: {}", openai_stream_chunk),
|
||||
sse_transformed_lines: format!("data: {}", openai_stream_chunk),
|
||||
provider_stream_response: None,
|
||||
};
|
||||
|
||||
|
|
@ -1223,8 +1249,8 @@ mod tests {
|
|||
let sse_event = SseEvent {
|
||||
data: Some(anthropic_event.to_string()),
|
||||
event: None,
|
||||
raw_line: format!("data: {}", anthropic_event.to_string()),
|
||||
sse_transformed_lines: format!("data: {}", anthropic_event.to_string()),
|
||||
raw_line: format!("data: {}", anthropic_event),
|
||||
sse_transformed_lines: format!("data: {}", anthropic_event),
|
||||
provider_stream_response: None,
|
||||
};
|
||||
|
||||
|
|
@ -1314,8 +1340,8 @@ mod tests {
|
|||
let sse_event = SseEvent {
|
||||
data: Some(openai_stream_chunk.to_string()),
|
||||
event: None,
|
||||
raw_line: format!("data: {}", openai_stream_chunk.to_string()),
|
||||
sse_transformed_lines: format!("data: {}", openai_stream_chunk.to_string()),
|
||||
raw_line: format!("data: {}", openai_stream_chunk),
|
||||
sse_transformed_lines: format!("data: {}", openai_stream_chunk),
|
||||
provider_stream_response: None,
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -11,11 +11,11 @@ pub trait ExtractText {
|
|||
/// Trait for utility functions on content collections
|
||||
pub trait ContentUtils<T> {
|
||||
fn extract_tool_calls(&self) -> Result<Option<Vec<ToolCall>>, TransformError>;
|
||||
fn split_for_openai(
|
||||
&self,
|
||||
) -> Result<(Vec<ContentPart>, Vec<ToolCall>, Vec<(String, String, bool)>), TransformError>;
|
||||
fn split_for_openai(&self) -> Result<SplitForOpenAIResult, TransformError>;
|
||||
}
|
||||
|
||||
pub type SplitForOpenAIResult = (Vec<ContentPart>, Vec<ToolCall>, Vec<(String, String, bool)>);
|
||||
|
||||
/// Helper to create a current unix timestamp
|
||||
pub fn current_timestamp() -> u64 {
|
||||
SystemTime::now()
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ impl TryFrom<AnthropicMessagesRequest> for ChatCompletionsRequest {
|
|||
}
|
||||
|
||||
// Convert tools and tool choice
|
||||
let openai_tools = req.tools.map(|tools| convert_anthropic_tools(tools));
|
||||
let openai_tools = req.tools.map(convert_anthropic_tools);
|
||||
let (openai_tool_choice, parallel_tool_calls) =
|
||||
convert_anthropic_tool_choice(req.tool_choice);
|
||||
|
||||
|
|
@ -218,18 +218,18 @@ impl TryFrom<MessagesMessage> for Vec<Message> {
|
|||
}
|
||||
|
||||
// Role Conversions
|
||||
impl Into<Role> for MessagesRole {
|
||||
fn into(self) -> Role {
|
||||
match self {
|
||||
impl From<MessagesRole> for Role {
|
||||
fn from(val: MessagesRole) -> Self {
|
||||
match val {
|
||||
MessagesRole::User => Role::User,
|
||||
MessagesRole::Assistant => Role::Assistant,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Into<MessagesStopReason> for FinishReason {
|
||||
fn into(self) -> MessagesStopReason {
|
||||
match self {
|
||||
impl From<FinishReason> for MessagesStopReason {
|
||||
fn from(val: FinishReason) -> Self {
|
||||
match val {
|
||||
FinishReason::Stop => MessagesStopReason::EndTurn,
|
||||
FinishReason::Length => MessagesStopReason::MaxTokens,
|
||||
FinishReason::ToolCalls => MessagesStopReason::ToolUse,
|
||||
|
|
@ -239,11 +239,11 @@ impl Into<MessagesStopReason> for FinishReason {
|
|||
}
|
||||
}
|
||||
|
||||
impl Into<MessagesUsage> for Usage {
|
||||
fn into(self) -> MessagesUsage {
|
||||
impl From<Usage> for MessagesUsage {
|
||||
fn from(val: Usage) -> Self {
|
||||
MessagesUsage {
|
||||
input_tokens: self.prompt_tokens,
|
||||
output_tokens: self.completion_tokens,
|
||||
input_tokens: val.prompt_tokens,
|
||||
output_tokens: val.completion_tokens,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
}
|
||||
|
|
@ -251,9 +251,9 @@ impl Into<MessagesUsage> for Usage {
|
|||
}
|
||||
|
||||
// System Prompt Conversions
|
||||
impl Into<Message> for MessagesSystemPrompt {
|
||||
fn into(self) -> Message {
|
||||
let system_content = match self {
|
||||
impl From<MessagesSystemPrompt> for Message {
|
||||
fn from(val: MessagesSystemPrompt) -> Self {
|
||||
let system_content = match val {
|
||||
MessagesSystemPrompt::Single(text) => MessageContent::Text(text),
|
||||
MessagesSystemPrompt::Blocks(blocks) => MessageContent::Text(blocks.extract_text()),
|
||||
};
|
||||
|
|
@ -384,12 +384,8 @@ impl TryFrom<MessagesMessage> for BedrockMessage {
|
|||
ToolResultContent::Blocks(blocks) => {
|
||||
let mut result_blocks = Vec::new();
|
||||
for result_block in blocks {
|
||||
match result_block {
|
||||
crate::apis::anthropic::MessagesContentBlock::Text { text, .. } => {
|
||||
result_blocks.push(ToolResultContentBlock::Text { text });
|
||||
}
|
||||
// For now, skip other content types in tool results
|
||||
_ => {}
|
||||
if let crate::apis::anthropic::MessagesContentBlock::Text { text, .. } = result_block {
|
||||
result_blocks.push(ToolResultContentBlock::Text { text });
|
||||
}
|
||||
}
|
||||
result_blocks
|
||||
|
|
|
|||
|
|
@ -14,7 +14,8 @@ use crate::apis::openai::{
|
|||
};
|
||||
|
||||
use crate::apis::openai_responses::{
|
||||
ResponsesAPIRequest, InputContent, InputItem, InputParam, MessageRole, Modality, ReasoningEffort, Tool as ResponsesTool, ToolChoice as ResponsesToolChoice
|
||||
InputContent, InputItem, InputParam, MessageRole, Modality, ReasoningEffort,
|
||||
ResponsesAPIRequest, Tool as ResponsesTool, ToolChoice as ResponsesToolChoice,
|
||||
};
|
||||
use crate::clients::TransformError;
|
||||
use crate::transforms::lib::ExtractText;
|
||||
|
|
@ -27,9 +28,9 @@ type AnthropicMessagesRequest = MessagesRequest;
|
|||
// MAIN REQUEST TRANSFORMATIONS
|
||||
// ============================================================================
|
||||
|
||||
impl Into<MessagesSystemPrompt> for Message {
|
||||
fn into(self) -> MessagesSystemPrompt {
|
||||
let system_text = match self.content {
|
||||
impl From<Message> for MessagesSystemPrompt {
|
||||
fn from(val: Message) -> Self {
|
||||
let system_text = match val.content {
|
||||
MessageContent::Text(text) => text,
|
||||
MessageContent::Parts(parts) => parts.extract_text(),
|
||||
};
|
||||
|
|
@ -163,7 +164,7 @@ impl TryFrom<Message> for BedrockMessage {
|
|||
let has_tool_calls = message
|
||||
.tool_calls
|
||||
.as_ref()
|
||||
.map_or(false, |calls| !calls.is_empty());
|
||||
.is_some_and(|calls| !calls.is_empty());
|
||||
|
||||
// Add text content if it's non-empty, or if we have no tool calls (to avoid empty content)
|
||||
if !text_content.is_empty() {
|
||||
|
|
@ -252,7 +253,6 @@ impl TryFrom<ResponsesAPIRequest> for ChatCompletionsRequest {
|
|||
type Error = TransformError;
|
||||
|
||||
fn try_from(req: ResponsesAPIRequest) -> Result<Self, Self::Error> {
|
||||
|
||||
// Convert input to messages
|
||||
let messages = match req.input {
|
||||
InputParam::Text(text) => {
|
||||
|
|
@ -282,50 +282,27 @@ impl TryFrom<ResponsesAPIRequest> for ChatCompletionsRequest {
|
|||
|
||||
// Convert each input item
|
||||
for item in items {
|
||||
match item {
|
||||
InputItem::Message(input_msg) => {
|
||||
let role = match input_msg.role {
|
||||
MessageRole::User => Role::User,
|
||||
MessageRole::Assistant => Role::Assistant,
|
||||
MessageRole::System => Role::System,
|
||||
MessageRole::Developer => Role::System, // Map developer to system
|
||||
};
|
||||
if let InputItem::Message(input_msg) = item {
|
||||
let role = match input_msg.role {
|
||||
MessageRole::User => Role::User,
|
||||
MessageRole::Assistant => Role::Assistant,
|
||||
MessageRole::System => Role::System,
|
||||
MessageRole::Developer => Role::System, // Map developer to system
|
||||
};
|
||||
|
||||
// Convert content based on MessageContent type
|
||||
let content = match &input_msg.content {
|
||||
crate::apis::openai_responses::MessageContent::Text(text) => {
|
||||
// Simple text content
|
||||
MessageContent::Text(text.clone())
|
||||
}
|
||||
crate::apis::openai_responses::MessageContent::Items(content_items) => {
|
||||
// Check if it's a single text item (can use simple text format)
|
||||
if content_items.len() == 1 {
|
||||
if let InputContent::InputText { text } = &content_items[0] {
|
||||
MessageContent::Text(text.clone())
|
||||
} else {
|
||||
// Single non-text item - use parts format
|
||||
MessageContent::Parts(
|
||||
content_items.iter()
|
||||
.filter_map(|c| match c {
|
||||
InputContent::InputText { text } => {
|
||||
Some(crate::apis::openai::ContentPart::Text { text: text.clone() })
|
||||
}
|
||||
InputContent::InputImage { image_url, .. } => {
|
||||
Some(crate::apis::openai::ContentPart::ImageUrl {
|
||||
image_url: crate::apis::openai::ImageUrl {
|
||||
url: image_url.clone(),
|
||||
detail: None,
|
||||
}
|
||||
})
|
||||
}
|
||||
InputContent::InputFile { .. } => None, // Skip files for now
|
||||
InputContent::InputAudio { .. } => None, // Skip audio for now
|
||||
})
|
||||
.collect()
|
||||
)
|
||||
}
|
||||
// Convert content based on MessageContent type
|
||||
let content = match &input_msg.content {
|
||||
crate::apis::openai_responses::MessageContent::Text(text) => {
|
||||
// Simple text content
|
||||
MessageContent::Text(text.clone())
|
||||
}
|
||||
crate::apis::openai_responses::MessageContent::Items(content_items) => {
|
||||
// Check if it's a single text item (can use simple text format)
|
||||
if content_items.len() == 1 {
|
||||
if let InputContent::InputText { text } = &content_items[0] {
|
||||
MessageContent::Text(text.clone())
|
||||
} else {
|
||||
// Multiple content items - convert to parts
|
||||
// Single non-text item - use parts format
|
||||
MessageContent::Parts(
|
||||
content_items.iter()
|
||||
.filter_map(|c| match c {
|
||||
|
|
@ -346,20 +323,41 @@ impl TryFrom<ResponsesAPIRequest> for ChatCompletionsRequest {
|
|||
.collect()
|
||||
)
|
||||
}
|
||||
} else {
|
||||
// Multiple content items - convert to parts
|
||||
MessageContent::Parts(
|
||||
content_items
|
||||
.iter()
|
||||
.filter_map(|c| match c {
|
||||
InputContent::InputText { text } => {
|
||||
Some(crate::apis::openai::ContentPart::Text {
|
||||
text: text.clone(),
|
||||
})
|
||||
}
|
||||
InputContent::InputImage { image_url, .. } => Some(
|
||||
crate::apis::openai::ContentPart::ImageUrl {
|
||||
image_url: crate::apis::openai::ImageUrl {
|
||||
url: image_url.clone(),
|
||||
detail: None,
|
||||
},
|
||||
},
|
||||
),
|
||||
InputContent::InputFile { .. } => None, // Skip files for now
|
||||
InputContent::InputAudio { .. } => None, // Skip audio for now
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
converted_messages.push(Message {
|
||||
role,
|
||||
content,
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
});
|
||||
}
|
||||
// Skip non-message items (references, outputs) for now
|
||||
// These would need special handling in chat completions format
|
||||
_ => {}
|
||||
converted_messages.push(Message {
|
||||
role,
|
||||
content,
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -474,7 +472,7 @@ impl TryFrom<ChatCompletionsRequest> for AnthropicMessagesRequest {
|
|||
}
|
||||
|
||||
// Convert tools and tool choice
|
||||
let anthropic_tools = req.tools.map(|tools| convert_openai_tools(tools));
|
||||
let anthropic_tools = req.tools.map(convert_openai_tools);
|
||||
let anthropic_tool_choice =
|
||||
convert_openai_tool_choice(req.tool_choice, req.parallel_tool_calls);
|
||||
|
||||
|
|
|
|||
|
|
@ -13,18 +13,14 @@ use crate::apis::openai_responses::{
|
|||
pub fn convert_responses_output_to_input_items(output: &OutputItem) -> Option<InputItem> {
|
||||
match output {
|
||||
// Convert output messages to input messages
|
||||
OutputItem::Message {
|
||||
role, content, ..
|
||||
} => {
|
||||
OutputItem::Message { role, content, .. } => {
|
||||
let input_content: Vec<InputContent> = content
|
||||
.iter()
|
||||
.filter_map(|c| match c {
|
||||
OutputContent::OutputText { text, .. } => Some(InputContent::InputText {
|
||||
text: text.clone(),
|
||||
}),
|
||||
OutputContent::OutputAudio {
|
||||
data, ..
|
||||
} => Some(InputContent::InputAudio {
|
||||
OutputContent::OutputText { text, .. } => {
|
||||
Some(InputContent::InputText { text: text.clone() })
|
||||
}
|
||||
OutputContent::OutputAudio { data, .. } => Some(InputContent::InputAudio {
|
||||
data: data.clone(),
|
||||
format: None, // Format not preserved in output
|
||||
}),
|
||||
|
|
@ -84,7 +80,7 @@ pub fn outputs_to_inputs(outputs: &[OutputItem]) -> Vec<InputItem> {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::apis::openai_responses::{OutputItemStatus};
|
||||
use crate::apis::openai_responses::OutputItemStatus;
|
||||
|
||||
#[test]
|
||||
fn test_output_message_to_input() {
|
||||
|
|
@ -135,14 +131,12 @@ mod tests {
|
|||
InputItem::Message(msg) => {
|
||||
assert!(matches!(msg.role, MessageRole::Assistant));
|
||||
match &msg.content {
|
||||
MessageContent::Items(items) => {
|
||||
match &items[0] {
|
||||
InputContent::InputText { text } => {
|
||||
assert!(text.contains("get_weather"));
|
||||
}
|
||||
_ => panic!("Expected InputText"),
|
||||
MessageContent::Items(items) => match &items[0] {
|
||||
InputContent::InputText { text } => {
|
||||
assert!(text.contains("get_weather"));
|
||||
}
|
||||
}
|
||||
_ => panic!("Expected InputText"),
|
||||
},
|
||||
_ => panic!("Expected MessageContent::Items"),
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
use crate::apis::amazon_bedrock::{ConverseOutput, ConverseResponse, StopReason};
|
||||
use crate::apis::anthropic::{
|
||||
MessagesContentBlock, MessagesResponse,
|
||||
MessagesRole, MessagesStopReason, MessagesUsage,
|
||||
MessagesContentBlock, MessagesResponse, MessagesRole, MessagesStopReason, MessagesUsage,
|
||||
};
|
||||
use crate::apis::openai::ChatCompletionsResponse;
|
||||
use crate::clients::TransformError;
|
||||
|
|
@ -115,7 +114,6 @@ impl TryFrom<ConverseResponse> for MessagesResponse {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
/// Convert Bedrock Message to Anthropic content blocks
|
||||
///
|
||||
/// This function handles the conversion between Amazon Bedrock Converse API format
|
||||
|
|
|
|||
|
|
@ -1,9 +1,5 @@
|
|||
use crate::apis::amazon_bedrock::{
|
||||
ConverseOutput, ConverseResponse, StopReason,
|
||||
};
|
||||
use crate::apis::anthropic::{
|
||||
MessagesContentBlock, MessagesResponse, MessagesUsage,
|
||||
};
|
||||
use crate::apis::amazon_bedrock::{ConverseOutput, ConverseResponse, StopReason};
|
||||
use crate::apis::anthropic::{MessagesContentBlock, MessagesResponse, MessagesUsage};
|
||||
use crate::apis::openai::{
|
||||
ChatCompletionsResponse, Choice, FinishReason, MessageContent, ResponseMessage, Role, Usage,
|
||||
};
|
||||
|
|
@ -16,12 +12,12 @@ use crate::transforms::lib::*;
|
|||
// ============================================================================
|
||||
|
||||
// Usage Conversions
|
||||
impl Into<Usage> for MessagesUsage {
|
||||
fn into(self) -> Usage {
|
||||
impl From<MessagesUsage> for Usage {
|
||||
fn from(val: MessagesUsage) -> Self {
|
||||
Usage {
|
||||
prompt_tokens: self.input_tokens,
|
||||
completion_tokens: self.output_tokens,
|
||||
total_tokens: self.input_tokens + self.output_tokens,
|
||||
prompt_tokens: val.input_tokens,
|
||||
completion_tokens: val.output_tokens,
|
||||
total_tokens: val.input_tokens + val.output_tokens,
|
||||
prompt_tokens_details: None,
|
||||
completion_tokens_details: None,
|
||||
}
|
||||
|
|
@ -203,7 +199,6 @@ impl TryFrom<ChatCompletionsResponse> for ResponsesAPIResponse {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
impl TryFrom<MessagesResponse> for ChatCompletionsResponse {
|
||||
type Error = TransformError;
|
||||
|
||||
|
|
@ -415,7 +410,6 @@ fn convert_anthropic_content_to_openai(
|
|||
Ok(MessageContent::Text(text_parts.join("\n")))
|
||||
}
|
||||
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
@ -994,8 +988,15 @@ mod tests {
|
|||
let responses_api: ResponsesAPIResponse = chat_response.try_into().unwrap();
|
||||
|
||||
// Response ID should be generated with resp_ prefix
|
||||
assert!(responses_api.id.starts_with("resp_"), "Response ID should start with 'resp_'");
|
||||
assert_eq!(responses_api.id.len(), 37, "Response ID should be resp_ + 32 char UUID");
|
||||
assert!(
|
||||
responses_api.id.starts_with("resp_"),
|
||||
"Response ID should start with 'resp_'"
|
||||
);
|
||||
assert_eq!(
|
||||
responses_api.id.len(),
|
||||
37,
|
||||
"Response ID should be resp_ + 32 char UUID"
|
||||
);
|
||||
assert_eq!(responses_api.object, "response");
|
||||
assert_eq!(responses_api.model, "gpt-4");
|
||||
|
||||
|
|
@ -1008,11 +1009,7 @@ mod tests {
|
|||
// Check output items
|
||||
assert_eq!(responses_api.output.len(), 1);
|
||||
match &responses_api.output[0] {
|
||||
OutputItem::Message {
|
||||
role,
|
||||
content,
|
||||
..
|
||||
} => {
|
||||
OutputItem::Message { role, content, .. } => {
|
||||
assert_eq!(role, "assistant");
|
||||
assert_eq!(content.len(), 1);
|
||||
match &content[0] {
|
||||
|
|
@ -1163,6 +1160,9 @@ mod tests {
|
|||
}
|
||||
|
||||
// Verify status is Completed for tool_calls finish reason
|
||||
assert!(matches!(responses_api.status, crate::apis::openai_responses::ResponseStatus::Completed));
|
||||
assert!(matches!(
|
||||
responses_api.status,
|
||||
crate::apis::openai_responses::ResponseStatus::Completed
|
||||
));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,12 +1,9 @@
|
|||
use crate::apis::amazon_bedrock::{
|
||||
ContentBlockDelta, ConverseStreamEvent,
|
||||
};
|
||||
use crate::apis::amazon_bedrock::{ContentBlockDelta, ConverseStreamEvent};
|
||||
use crate::apis::anthropic::{
|
||||
MessagesContentBlock, MessagesContentDelta, MessagesMessageDelta,
|
||||
MessagesRole, MessagesStopReason, MessagesStreamEvent, MessagesStreamMessage, MessagesUsage,
|
||||
};
|
||||
use crate::apis::openai::{ ChatCompletionsStreamResponse, ToolCallDelta,
|
||||
MessagesContentBlock, MessagesContentDelta, MessagesMessageDelta, MessagesRole,
|
||||
MessagesStopReason, MessagesStreamEvent, MessagesStreamMessage, MessagesUsage,
|
||||
};
|
||||
use crate::apis::openai::{ChatCompletionsStreamResponse, ToolCallDelta};
|
||||
use crate::clients::TransformError;
|
||||
use serde_json::Value;
|
||||
|
||||
|
|
@ -86,10 +83,10 @@ impl TryFrom<ChatCompletionsStreamResponse> for MessagesStreamEvent {
|
|||
}
|
||||
}
|
||||
|
||||
impl Into<String> for MessagesStreamEvent {
|
||||
fn into(self) -> String {
|
||||
let transformed_json = serde_json::to_string(&self).unwrap_or_default();
|
||||
let event_type = match &self {
|
||||
impl From<MessagesStreamEvent> for String {
|
||||
fn from(val: MessagesStreamEvent) -> Self {
|
||||
let transformed_json = serde_json::to_string(&val).unwrap_or_default();
|
||||
let event_type = match &val {
|
||||
MessagesStreamEvent::MessageStart { .. } => "message_start",
|
||||
MessagesStreamEvent::ContentBlockStart { .. } => "content_block_start",
|
||||
MessagesStreamEvent::ContentBlockDelta { .. } => "content_block_delta",
|
||||
|
|
@ -194,10 +191,18 @@ impl TryFrom<ConverseStreamEvent> for MessagesStreamEvent {
|
|||
let anthropic_stop_reason = match stop_event.stop_reason {
|
||||
crate::apis::amazon_bedrock::StopReason::EndTurn => MessagesStopReason::EndTurn,
|
||||
crate::apis::amazon_bedrock::StopReason::ToolUse => MessagesStopReason::ToolUse,
|
||||
crate::apis::amazon_bedrock::StopReason::MaxTokens => MessagesStopReason::MaxTokens,
|
||||
crate::apis::amazon_bedrock::StopReason::StopSequence => MessagesStopReason::EndTurn,
|
||||
crate::apis::amazon_bedrock::StopReason::GuardrailIntervened => MessagesStopReason::Refusal,
|
||||
crate::apis::amazon_bedrock::StopReason::ContentFiltered => MessagesStopReason::Refusal,
|
||||
crate::apis::amazon_bedrock::StopReason::MaxTokens => {
|
||||
MessagesStopReason::MaxTokens
|
||||
}
|
||||
crate::apis::amazon_bedrock::StopReason::StopSequence => {
|
||||
MessagesStopReason::EndTurn
|
||||
}
|
||||
crate::apis::amazon_bedrock::StopReason::GuardrailIntervened => {
|
||||
MessagesStopReason::Refusal
|
||||
}
|
||||
crate::apis::amazon_bedrock::StopReason::ContentFiltered => {
|
||||
MessagesStopReason::Refusal
|
||||
}
|
||||
};
|
||||
|
||||
Ok(MessagesStreamEvent::MessageDelta {
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
use crate::apis::amazon_bedrock::{ ConverseStreamEvent, StopReason};
|
||||
use crate::apis::amazon_bedrock::{ConverseStreamEvent, StopReason};
|
||||
use crate::apis::anthropic::{
|
||||
MessagesContentBlock, MessagesContentDelta, MessagesStopReason, MessagesStreamEvent};
|
||||
use crate::apis::openai::{ ChatCompletionsStreamResponse,FinishReason,
|
||||
FunctionCallDelta, MessageDelta, Role, StreamChoice, ToolCallDelta, Usage,
|
||||
MessagesContentBlock, MessagesContentDelta, MessagesStopReason, MessagesStreamEvent,
|
||||
};
|
||||
use crate::apis::openai::{
|
||||
ChatCompletionsStreamResponse, FinishReason, FunctionCallDelta, MessageDelta, Role,
|
||||
StreamChoice, ToolCallDelta, Usage,
|
||||
};
|
||||
use crate::apis::openai_responses::ResponsesAPIStreamEvent;
|
||||
|
||||
|
|
@ -58,11 +60,14 @@ impl TryFrom<MessagesStreamEvent> for ChatCompletionsStreamResponse {
|
|||
None,
|
||||
)),
|
||||
|
||||
MessagesStreamEvent::ContentBlockStart { content_block, index } => {
|
||||
convert_content_block_start(content_block, index)
|
||||
}
|
||||
MessagesStreamEvent::ContentBlockStart {
|
||||
content_block,
|
||||
index,
|
||||
} => convert_content_block_start(content_block, index),
|
||||
|
||||
MessagesStreamEvent::ContentBlockDelta { delta, index } => convert_content_delta(delta, index),
|
||||
MessagesStreamEvent::ContentBlockDelta { delta, index } => {
|
||||
convert_content_delta(delta, index)
|
||||
}
|
||||
|
||||
MessagesStreamEvent::ContentBlockStop { .. } => Ok(create_empty_openai_chunk()),
|
||||
|
||||
|
|
@ -427,9 +432,9 @@ fn create_empty_openai_chunk() -> ChatCompletionsStreamResponse {
|
|||
}
|
||||
|
||||
// Stop Reason Conversions
|
||||
impl Into<FinishReason> for MessagesStopReason {
|
||||
fn into(self) -> FinishReason {
|
||||
match self {
|
||||
impl From<MessagesStopReason> for FinishReason {
|
||||
fn from(val: MessagesStopReason) -> Self {
|
||||
match val {
|
||||
MessagesStopReason::EndTurn => FinishReason::Stop,
|
||||
MessagesStopReason::MaxTokens => FinishReason::Length,
|
||||
MessagesStopReason::StopSequence => FinishReason::Stop,
|
||||
|
|
@ -456,34 +461,37 @@ impl TryFrom<ChatCompletionsStreamResponse> for ResponsesAPIStreamEvent {
|
|||
if let Some(tool_call) = tool_calls.first() {
|
||||
// Extract call_id and name if available (metadata from initial event)
|
||||
let call_id = tool_call.id.clone();
|
||||
let function_name = tool_call.function.as_ref()
|
||||
.and_then(|f| f.name.clone());
|
||||
let function_name = tool_call.function.as_ref().and_then(|f| f.name.clone());
|
||||
|
||||
// Check if we have function metadata (name, id)
|
||||
if let Some(function) = &tool_call.function {
|
||||
// If we have arguments delta, return that
|
||||
if let Some(args) = &function.arguments {
|
||||
return Ok(ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta {
|
||||
output_index: choice.index as i32,
|
||||
item_id: "".to_string(), // Buffer will fill this
|
||||
delta: args.clone(),
|
||||
sequence_number: 0, // Buffer will fill this
|
||||
call_id,
|
||||
name: function_name,
|
||||
});
|
||||
return Ok(
|
||||
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta {
|
||||
output_index: choice.index as i32,
|
||||
item_id: "".to_string(), // Buffer will fill this
|
||||
delta: args.clone(),
|
||||
sequence_number: 0, // Buffer will fill this
|
||||
call_id,
|
||||
name: function_name,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// If we have function name but no arguments yet (initial tool call event)
|
||||
// Return an empty arguments delta so the buffer knows to create the item
|
||||
if function.name.is_some() {
|
||||
return Ok(ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta {
|
||||
output_index: choice.index as i32,
|
||||
item_id: "".to_string(), // Buffer will fill this
|
||||
delta: "".to_string(), // Empty delta signals this is the initial event
|
||||
sequence_number: 0, // Buffer will fill this
|
||||
call_id,
|
||||
name: function_name,
|
||||
});
|
||||
return Ok(
|
||||
ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta {
|
||||
output_index: choice.index as i32,
|
||||
item_id: "".to_string(), // Buffer will fill this
|
||||
delta: "".to_string(), // Empty delta signals this is the initial event
|
||||
sequence_number: 0, // Buffer will fill this
|
||||
call_id,
|
||||
name: function_name,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -94,8 +94,8 @@ impl StreamContext {
|
|||
fn request_identifier(&self) -> String {
|
||||
self.request_id
|
||||
.as_ref()
|
||||
.filter(|id| !id.is_empty()) // Filter out empty strings
|
||||
.map(|id| id.clone())
|
||||
.filter(|id| !id.is_empty())
|
||||
.cloned()
|
||||
.unwrap_or_else(|| "NO_REQUEST_ID".to_string())
|
||||
}
|
||||
fn llm_provider(&self) -> &LlmProvider {
|
||||
|
|
@ -504,7 +504,7 @@ impl StreamContext {
|
|||
// Get accumulated bytes from buffer and return
|
||||
match self.sse_buffer.as_mut() {
|
||||
Some(buffer) => {
|
||||
let bytes = buffer.into_bytes();
|
||||
let bytes = buffer.to_bytes();
|
||||
if !bytes.is_empty() {
|
||||
let content = String::from_utf8_lossy(&bytes);
|
||||
debug!(
|
||||
|
|
@ -623,7 +623,7 @@ impl StreamContext {
|
|||
// Get accumulated bytes from buffer and return
|
||||
match self.sse_buffer.as_mut() {
|
||||
Some(buffer) => {
|
||||
let bytes = buffer.into_bytes();
|
||||
let bytes = buffer.to_bytes();
|
||||
if !bytes.is_empty() {
|
||||
let content = String::from_utf8_lossy(&bytes);
|
||||
debug!(
|
||||
|
|
|
|||
|
|
@ -142,8 +142,7 @@ impl HttpContext for StreamContext {
|
|||
let last_user_prompt = match deserialized_body
|
||||
.messages
|
||||
.iter()
|
||||
.filter(|msg| msg.role == USER_ROLE)
|
||||
.last()
|
||||
.rfind(|msg| msg.role == USER_ROLE)
|
||||
{
|
||||
Some(content) => content,
|
||||
None => {
|
||||
|
|
@ -155,11 +154,8 @@ impl HttpContext for StreamContext {
|
|||
self.user_prompt = Some(last_user_prompt.clone());
|
||||
|
||||
// convert prompt targets to ChatCompletionTool
|
||||
let tool_calls: Vec<ChatCompletionTool> = self
|
||||
.prompt_targets
|
||||
.iter()
|
||||
.map(|(_, pt)| pt.into())
|
||||
.collect();
|
||||
let tool_calls: Vec<ChatCompletionTool> =
|
||||
self.prompt_targets.values().map(|pt| pt.into()).collect();
|
||||
|
||||
let mut metadata = deserialized_body.metadata.clone();
|
||||
|
||||
|
|
|
|||
|
|
@ -376,21 +376,22 @@ impl StreamContext {
|
|||
|
||||
// Parse arguments JSON string into HashMap
|
||||
// Note: convert from serde_json::Value to serde_yaml::Value for compatibility
|
||||
let tool_params: Option<HashMap<String, serde_yaml::Value>> = match serde_json::from_str::<HashMap<String, serde_json::Value>>(tool_params_str) {
|
||||
Ok(json_params) => {
|
||||
let yaml_params: HashMap<String, serde_yaml::Value> = json_params
|
||||
.into_iter()
|
||||
.filter_map(|(k, v)| {
|
||||
serde_yaml::to_value(&v).ok().map(|yaml_v| (k, yaml_v))
|
||||
})
|
||||
.collect();
|
||||
Some(yaml_params)
|
||||
},
|
||||
Err(e) => {
|
||||
warn!("Failed to parse tool call arguments: {}", e);
|
||||
None
|
||||
}
|
||||
};
|
||||
let tool_params: Option<HashMap<String, serde_yaml::Value>> =
|
||||
match serde_json::from_str::<HashMap<String, serde_json::Value>>(tool_params_str) {
|
||||
Ok(json_params) => {
|
||||
let yaml_params: HashMap<String, serde_yaml::Value> = json_params
|
||||
.into_iter()
|
||||
.filter_map(|(k, v)| {
|
||||
serde_yaml::to_value(&v).ok().map(|yaml_v| (k, yaml_v))
|
||||
})
|
||||
.collect();
|
||||
Some(yaml_params)
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to parse tool call arguments: {}", e);
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
let endpoint_details = prompt_target.endpoint.as_ref().unwrap();
|
||||
let endpoint_path: String = endpoint_details
|
||||
|
|
@ -629,10 +630,10 @@ impl StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
if system_prompt.is_some() {
|
||||
if let Some(system_prompt_text) = system_prompt {
|
||||
let system_prompt_message = Message {
|
||||
role: SYSTEM_ROLE.to_string(),
|
||||
content: Some(ContentType::Text(system_prompt.unwrap())),
|
||||
content: Some(ContentType::Text(system_prompt_text)),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue