mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
add provider specific request parsing in agents chat
This commit is contained in:
parent
54a3d45bbf
commit
3a2c1828ee
4 changed files with 571 additions and 37 deletions
|
|
@ -1,10 +1,13 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use bytes::Bytes;
|
||||
use hermesllm::apis::openai::ChatCompletionsRequest;
|
||||
use hermesllm::apis::OpenAIMessage;
|
||||
use hermesllm::clients::SupportedAPIsFromClient;
|
||||
use hermesllm::ProviderRequestType;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::BodyExt;
|
||||
use hyper::{Request, Response};
|
||||
use serde::ser::Error as SerError;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use super::agent_selector::{AgentSelectionError, AgentSelector};
|
||||
|
|
@ -35,7 +38,15 @@ pub async fn agent_chat(
|
|||
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
|
||||
trace_collector: Arc<common::traces::TraceCollector>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
match handle_agent_chat(request, router_service, agents_list, listeners, trace_collector).await {
|
||||
match handle_agent_chat(
|
||||
request,
|
||||
router_service,
|
||||
agents_list,
|
||||
listeners,
|
||||
trace_collector,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(response) => Ok(response),
|
||||
Err(err) => {
|
||||
// Check if this is a client error from the pipeline that should be cascaded
|
||||
|
|
@ -134,6 +145,13 @@ async fn handle_agent_chat(
|
|||
info!("Handling request for listener: {}", listener.name);
|
||||
|
||||
// Parse request body
|
||||
let request_path = request
|
||||
.uri()
|
||||
.path()
|
||||
.to_string()
|
||||
.strip_prefix("/agents")
|
||||
.unwrap()
|
||||
.to_string();
|
||||
let request_headers = request.headers().clone();
|
||||
let chat_request_bytes = request.collect().await?.to_bytes();
|
||||
|
||||
|
|
@ -142,15 +160,36 @@ async fn handle_agent_chat(
|
|||
String::from_utf8_lossy(&chat_request_bytes)
|
||||
);
|
||||
|
||||
let chat_completions_request: ChatCompletionsRequest =
|
||||
serde_json::from_slice(&chat_request_bytes).map_err(|err| {
|
||||
warn!(
|
||||
"Failed to parse request body as ChatCompletionsRequest: {}",
|
||||
err
|
||||
);
|
||||
AgentFilterChainError::RequestParsing(err)
|
||||
// Determine the API type from the endpoint
|
||||
let api_type =
|
||||
SupportedAPIsFromClient::from_endpoint(request_path.as_str()).ok_or_else(|| {
|
||||
let err_msg = format!("Unsupported endpoint: {}", request_path);
|
||||
warn!("{}", err_msg);
|
||||
AgentFilterChainError::RequestParsing(serde_json::Error::custom(err_msg))
|
||||
})?;
|
||||
|
||||
let client_request = match ProviderRequestType::try_from((&chat_request_bytes[..], &api_type)) {
|
||||
Ok(request) => request,
|
||||
Err(err) => {
|
||||
warn!("Failed to parse request as ProviderRequestType: {}", err);
|
||||
let err_msg = format!("Failed to parse request: {}", err);
|
||||
return Err(AgentFilterChainError::RequestParsing(
|
||||
serde_json::Error::custom(err_msg),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let message: Vec<OpenAIMessage> = client_request.get_message_history();
|
||||
|
||||
// let chat_completions_request: ChatCompletionsRequest =
|
||||
// serde_json::from_slice(&chat_request_bytes).map_err(|err| {
|
||||
// warn!(
|
||||
// "Failed to parse request body as ChatCompletionsRequest: {}",
|
||||
// err
|
||||
// );
|
||||
// AgentFilterChainError::RequestParsing(err)
|
||||
// })?;
|
||||
|
||||
// Extract trace parent for routing
|
||||
let trace_parent = request_headers
|
||||
.iter()
|
||||
|
|
@ -166,11 +205,7 @@ async fn handle_agent_chat(
|
|||
|
||||
// Select appropriate agent using arch router llm model
|
||||
let selected_agent = agent_selector
|
||||
.select_agent(
|
||||
&chat_completions_request.messages,
|
||||
&listener,
|
||||
trace_parent,
|
||||
)
|
||||
.select_agent(&message, &listener, trace_parent)
|
||||
.await?;
|
||||
|
||||
debug!("Processing agent pipeline: {}", selected_agent.id);
|
||||
|
|
@ -178,7 +213,7 @@ async fn handle_agent_chat(
|
|||
// Process the filter chain
|
||||
let chat_history = pipeline_processor
|
||||
.process_filter_chain(
|
||||
&chat_completions_request.messages,
|
||||
&message,
|
||||
&selected_agent,
|
||||
&agent_map,
|
||||
&request_headers,
|
||||
|
|
@ -196,7 +231,7 @@ async fn handle_agent_chat(
|
|||
let llm_response = pipeline_processor
|
||||
.invoke_terminal_agent(
|
||||
&chat_history,
|
||||
&chat_completions_request,
|
||||
client_request,
|
||||
terminal_agent,
|
||||
&request_headers,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,8 @@ use std::collections::HashMap;
|
|||
use common::configuration::{Agent, AgentFilterChain};
|
||||
use common::consts::{ARCH_UPSTREAM_HOST_HEADER, ENVOY_RETRY_HEADER};
|
||||
use common::traces::{SpanBuilder, SpanKind};
|
||||
use hermesllm::apis::openai::{ChatCompletionsRequest, Message};
|
||||
use hermesllm::{ProviderRequest, ProviderRequestType};
|
||||
use hermesllm::apis::openai::{Message};
|
||||
use hyper::header::HeaderMap;
|
||||
use opentelemetry::trace::TraceContextExt;
|
||||
use tracing::{debug, info, warn};
|
||||
|
|
@ -468,14 +469,15 @@ impl PipelineProcessor {
|
|||
pub async fn invoke_terminal_agent(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
original_request: &ChatCompletionsRequest,
|
||||
mut original_request: ProviderRequestType,
|
||||
terminal_agent: &Agent,
|
||||
request_headers: &HeaderMap,
|
||||
) -> Result<reqwest::Response, PipelineError> {
|
||||
let mut request = original_request.clone();
|
||||
request.messages = messages.to_vec();
|
||||
// let mut request = original_request.clone();
|
||||
original_request.set_messages(messages);
|
||||
|
||||
let request_body = serde_json::to_string(&request)?;
|
||||
let request_body = ProviderRequestType::to_bytes(&original_request).unwrap();
|
||||
// let request_body = serde_json::to_string(&request)?;
|
||||
debug!("Sending request to terminal agent {}", terminal_agent.id);
|
||||
|
||||
let mut agent_headers = request_headers.clone();
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
use brightstaff::handlers::agent_chat_completions::agent_chat;
|
||||
use brightstaff::handlers::function_calling::function_calling_chat_handler;
|
||||
use brightstaff::handlers::llm::llm_chat;
|
||||
use brightstaff::handlers::models::list_models;
|
||||
use brightstaff::handlers::function_calling::{function_calling_chat_handler};
|
||||
use brightstaff::router::llm_router::RouterService;
|
||||
use brightstaff::utils::tracing::init_tracer;
|
||||
use bytes::Bytes;
|
||||
|
|
@ -105,20 +105,23 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
info!("Tracing configuration found in arch_config.yaml");
|
||||
Some(true)
|
||||
} else {
|
||||
info!("No tracing configuration in arch_config.yaml, will check OTEL_TRACING_ENABLED env var");
|
||||
info!(
|
||||
"No tracing configuration in arch_config.yaml, will check OTEL_TRACING_ENABLED env var"
|
||||
);
|
||||
None
|
||||
};
|
||||
let trace_collector = Arc::new(TraceCollector::new(tracing_enabled));
|
||||
let _flusher_handle = trace_collector.clone().start_background_flusher();
|
||||
|
||||
|
||||
loop {
|
||||
let (stream, _) = listener.accept().await?;
|
||||
let peer_addr = stream.peer_addr()?;
|
||||
let io = TokioIo::new(stream);
|
||||
|
||||
let router_service: Arc<RouterService> = Arc::clone(&router_service);
|
||||
let model_aliases: Arc<Option<std::collections::HashMap<String, common::configuration::ModelAlias>>> = Arc::clone(&model_aliases);
|
||||
let model_aliases: Arc<
|
||||
Option<std::collections::HashMap<String, common::configuration::ModelAlias>>,
|
||||
> = Arc::clone(&model_aliases);
|
||||
let llm_provider_url = llm_provider_url.clone();
|
||||
|
||||
let llm_providers = llm_providers.clone();
|
||||
|
|
@ -136,18 +139,18 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
let trace_collector = trace_collector.clone();
|
||||
|
||||
async move {
|
||||
match (req.method(), req.uri().path()) {
|
||||
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH) => {
|
||||
let fully_qualified_url =
|
||||
format!("{}{}", llm_provider_url, req.uri().path());
|
||||
llm_chat(req, router_service, fully_qualified_url, model_aliases, llm_providers, trace_collector)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
}
|
||||
(&Method::POST, "/agents/v1/chat/completions") => {
|
||||
let fully_qualified_url =
|
||||
format!("{}{}", llm_provider_url, req.uri().path());
|
||||
agent_chat(
|
||||
let path = req.uri().path();
|
||||
|
||||
// Check if path starts with /agents
|
||||
if path.starts_with("/agents") {
|
||||
// Check if it matches one of the agent API paths
|
||||
let stripped_path = path.strip_prefix("/agents").unwrap();
|
||||
if matches!(
|
||||
stripped_path,
|
||||
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH
|
||||
) {
|
||||
let fully_qualified_url = format!("{}{}", llm_provider_url, stripped_path);
|
||||
return agent_chat(
|
||||
req,
|
||||
router_service,
|
||||
fully_qualified_url,
|
||||
|
|
@ -156,6 +159,26 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
trace_collector,
|
||||
)
|
||||
.with_context(parent_cx)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
match (req.method(), path) {
|
||||
(
|
||||
&Method::POST,
|
||||
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH,
|
||||
) => {
|
||||
let fully_qualified_url =
|
||||
format!("{}{}", llm_provider_url, req.uri().path());
|
||||
llm_chat(
|
||||
req,
|
||||
router_service,
|
||||
fully_qualified_url,
|
||||
model_aliases,
|
||||
llm_providers,
|
||||
trace_collector,
|
||||
)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -49,6 +49,353 @@ pub trait ProviderRequest: Send + Sync {
|
|||
fn get_temperature(&self) -> Option<f32>;
|
||||
}
|
||||
|
||||
impl ProviderRequestType {
|
||||
/// Get message history as OpenAI Message format
|
||||
/// This is useful for processing chat history across different provider formats
|
||||
pub fn get_message_history(&self) -> Vec<crate::apis::openai::Message> {
|
||||
use crate::apis::openai::{Message, MessageContent, Role};
|
||||
|
||||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.messages.clone(),
|
||||
Self::MessagesRequest(r) => {
|
||||
// Convert Anthropic messages to OpenAI format
|
||||
let mut openai_messages = Vec::new();
|
||||
|
||||
// Add system prompt as system message if present
|
||||
if let Some(system) = &r.system {
|
||||
openai_messages.push(system.clone().into());
|
||||
}
|
||||
|
||||
// Convert each Anthropic message to OpenAI format
|
||||
for msg in &r.messages {
|
||||
if let Ok(converted_msgs) = TryInto::<Vec<Message>>::try_into(msg.clone()) {
|
||||
openai_messages.extend(converted_msgs);
|
||||
}
|
||||
}
|
||||
|
||||
openai_messages
|
||||
}
|
||||
Self::BedrockConverse(r) => {
|
||||
// Convert Bedrock messages to OpenAI format
|
||||
let mut openai_messages = Vec::new();
|
||||
|
||||
// Add system messages if present
|
||||
if let Some(system) = &r.system {
|
||||
for sys_block in system {
|
||||
match sys_block {
|
||||
crate::apis::amazon_bedrock::SystemContentBlock::Text { text } => {
|
||||
openai_messages.push(Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text(text.clone()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
}
|
||||
_ => {} // Skip other system content types
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert conversation messages
|
||||
if let Some(messages) = &r.messages {
|
||||
for msg in messages {
|
||||
let role = match msg.role {
|
||||
crate::apis::amazon_bedrock::ConversationRole::User => Role::User,
|
||||
crate::apis::amazon_bedrock::ConversationRole::Assistant => Role::Assistant,
|
||||
};
|
||||
|
||||
// Extract text from content blocks
|
||||
let content = msg.content.iter()
|
||||
.filter_map(|block| {
|
||||
if let crate::apis::amazon_bedrock::ContentBlock::Text { text } = block {
|
||||
Some(text.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
openai_messages.push(Message {
|
||||
role,
|
||||
content: MessageContent::Text(content),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
openai_messages
|
||||
}
|
||||
Self::BedrockConverseStream(r) => {
|
||||
// Same as BedrockConverse
|
||||
let mut openai_messages = Vec::new();
|
||||
|
||||
if let Some(system) = &r.system {
|
||||
for sys_block in system {
|
||||
match sys_block {
|
||||
crate::apis::amazon_bedrock::SystemContentBlock::Text { text } => {
|
||||
openai_messages.push(Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text(text.clone()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
}
|
||||
_ => {} // Skip other system content types
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(messages) = &r.messages {
|
||||
for msg in messages {
|
||||
let role = match msg.role {
|
||||
crate::apis::amazon_bedrock::ConversationRole::User => Role::User,
|
||||
crate::apis::amazon_bedrock::ConversationRole::Assistant => Role::Assistant,
|
||||
};
|
||||
|
||||
let content = msg.content.iter()
|
||||
.filter_map(|block| {
|
||||
if let crate::apis::amazon_bedrock::ContentBlock::Text { text } = block {
|
||||
Some(text.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
openai_messages.push(Message {
|
||||
role,
|
||||
content: MessageContent::Text(content),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
openai_messages
|
||||
}
|
||||
Self::ResponsesAPIRequest(r) => {
|
||||
// Convert ResponsesAPIRequest input to a user message
|
||||
let mut openai_messages = Vec::new();
|
||||
|
||||
// Add instructions as system message if present
|
||||
if let Some(instructions) = &r.instructions {
|
||||
openai_messages.push(Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text(instructions.clone()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
}
|
||||
|
||||
// Convert input to messages
|
||||
use crate::apis::openai_responses::{InputParam, InputItem};
|
||||
match &r.input {
|
||||
InputParam::Text(text) => {
|
||||
openai_messages.push(Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text(text.clone()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
}
|
||||
InputParam::Items(items) => {
|
||||
for item in items {
|
||||
match item {
|
||||
InputItem::Message(msg) => {
|
||||
// Convert message role
|
||||
let role = match msg.role {
|
||||
crate::apis::openai_responses::MessageRole::User => Role::User,
|
||||
crate::apis::openai_responses::MessageRole::Assistant => Role::Assistant,
|
||||
crate::apis::openai_responses::MessageRole::System => Role::System,
|
||||
crate::apis::openai_responses::MessageRole::Developer => Role::System, // Map developer to system
|
||||
};
|
||||
|
||||
// Extract text from message content
|
||||
let content = msg.content.iter()
|
||||
.filter_map(|c| {
|
||||
if let crate::apis::openai_responses::InputContent::InputText { text } = c {
|
||||
Some(text.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
openai_messages.push(Message {
|
||||
role,
|
||||
content: MessageContent::Text(content),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
openai_messages
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Set message history from OpenAI Message format
|
||||
/// This converts OpenAI messages to the appropriate format for each provider type
|
||||
pub fn set_messages(&mut self, messages: &[crate::apis::openai::Message]) {
|
||||
match self {
|
||||
Self::ChatCompletionsRequest(r) => {
|
||||
r.messages = messages.to_vec();
|
||||
}
|
||||
Self::MessagesRequest(r) => {
|
||||
// Convert OpenAI messages to Anthropic format
|
||||
// Separate system messages from regular messages
|
||||
let mut system_messages = Vec::new();
|
||||
let mut regular_messages = Vec::new();
|
||||
|
||||
for msg in messages {
|
||||
if msg.role == crate::apis::openai::Role::System {
|
||||
system_messages.push(msg.clone());
|
||||
} else {
|
||||
regular_messages.push(msg.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Set system prompt if there are system messages
|
||||
if !system_messages.is_empty() {
|
||||
// Combine all system messages into one
|
||||
let system_text = system_messages.iter()
|
||||
.filter_map(|msg| {
|
||||
if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
|
||||
Some(text.as_str())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
r.system = Some(crate::apis::anthropic::MessagesSystemPrompt::Single(system_text));
|
||||
}
|
||||
|
||||
// Convert regular messages
|
||||
r.messages = regular_messages.iter()
|
||||
.filter_map(|msg| {
|
||||
msg.clone().try_into().ok()
|
||||
})
|
||||
.collect();
|
||||
}
|
||||
Self::BedrockConverse(r) | Self::BedrockConverseStream(r) => {
|
||||
// Convert OpenAI messages to Bedrock format
|
||||
use crate::apis::amazon_bedrock::{ContentBlock, ConversationRole, SystemContentBlock};
|
||||
|
||||
let mut system_blocks = Vec::new();
|
||||
let mut bedrock_messages = Vec::new();
|
||||
|
||||
for msg in messages {
|
||||
match msg.role {
|
||||
crate::apis::openai::Role::System => {
|
||||
if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
|
||||
system_blocks.push(SystemContentBlock::Text { text: text.clone() });
|
||||
}
|
||||
}
|
||||
crate::apis::openai::Role::User | crate::apis::openai::Role::Assistant => {
|
||||
let role = match msg.role {
|
||||
crate::apis::openai::Role::User => ConversationRole::User,
|
||||
crate::apis::openai::Role::Assistant => ConversationRole::Assistant,
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
let content = if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
|
||||
vec![ContentBlock::Text { text: text.clone() }]
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
bedrock_messages.push(crate::apis::amazon_bedrock::Message {
|
||||
role,
|
||||
content,
|
||||
});
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
if !system_blocks.is_empty() {
|
||||
r.system = Some(system_blocks);
|
||||
}
|
||||
r.messages = Some(bedrock_messages);
|
||||
}
|
||||
Self::ResponsesAPIRequest(r) => {
|
||||
// For ResponsesAPI, we need to convert messages back to input format
|
||||
// Extract system messages as instructions
|
||||
let system_text = messages.iter()
|
||||
.filter(|msg| msg.role == crate::apis::openai::Role::System)
|
||||
.filter_map(|msg| {
|
||||
if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
|
||||
Some(text.as_str())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
if !system_text.is_empty() {
|
||||
r.instructions = Some(system_text);
|
||||
}
|
||||
|
||||
// Convert user/assistant messages to InputParam
|
||||
// For simplicity, we'll use the last user message as the input
|
||||
// or combine all non-system messages
|
||||
let input_messages: Vec<_> = messages.iter()
|
||||
.filter(|msg| msg.role != crate::apis::openai::Role::System)
|
||||
.collect();
|
||||
|
||||
if !input_messages.is_empty() {
|
||||
// If there's only one message, use Text format
|
||||
if input_messages.len() == 1 {
|
||||
if let crate::apis::openai::MessageContent::Text(text) = &input_messages[0].content {
|
||||
r.input = crate::apis::openai_responses::InputParam::Text(text.clone());
|
||||
}
|
||||
} else {
|
||||
// Multiple messages - combine them as text for now
|
||||
// A more sophisticated approach would use InputParam::Items
|
||||
let combined_text = input_messages.iter()
|
||||
.filter_map(|msg| {
|
||||
if let crate::apis::openai::MessageContent::Text(text) = &msg.content {
|
||||
Some(format!("{}: {}",
|
||||
match msg.role {
|
||||
crate::apis::openai::Role::User => "User",
|
||||
crate::apis::openai::Role::Assistant => "Assistant",
|
||||
_ => "Unknown",
|
||||
},
|
||||
text
|
||||
))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
r.input = crate::apis::openai_responses::InputParam::Text(combined_text);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderRequest for ProviderRequestType {
|
||||
fn model(&self) -> &str {
|
||||
match self {
|
||||
|
|
@ -934,4 +1281,131 @@ mod tests {
|
|||
.message
|
||||
.contains("OpenAI ChatCompletions, Anthropic Messages, and OpenAI Responses"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_message_history_chat_completions() {
|
||||
use crate::apis::openai::{Message, MessageContent, Role};
|
||||
|
||||
let chat_req = ChatCompletionsRequest {
|
||||
model: "gpt-4".to_string(),
|
||||
messages: vec![
|
||||
Message {
|
||||
role: Role::System,
|
||||
content: MessageContent::Text("You are helpful".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
Message {
|
||||
role: Role::User,
|
||||
content: MessageContent::Text("Hello!".to_string()),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let provider_req = ProviderRequestType::ChatCompletionsRequest(chat_req);
|
||||
let messages = provider_req.get_message_history();
|
||||
|
||||
assert_eq!(messages.len(), 2);
|
||||
assert_eq!(messages[0].role, Role::System);
|
||||
assert_eq!(messages[1].role, Role::User);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_message_history_anthropic_messages() {
|
||||
use crate::apis::anthropic::{
|
||||
MessagesMessage, MessagesMessageContent, MessagesRequest, MessagesRole,
|
||||
MessagesSystemPrompt,
|
||||
};
|
||||
|
||||
let anthropic_req = MessagesRequest {
|
||||
model: "claude-3-sonnet".to_string(),
|
||||
messages: vec![MessagesMessage {
|
||||
role: MessagesRole::User,
|
||||
content: MessagesMessageContent::Single("Hello!".to_string()),
|
||||
}],
|
||||
system: Some(MessagesSystemPrompt::Single(
|
||||
"You are helpful".to_string(),
|
||||
)),
|
||||
max_tokens: 100,
|
||||
container: None,
|
||||
mcp_servers: None,
|
||||
metadata: None,
|
||||
service_tier: None,
|
||||
thinking: None,
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
top_k: None,
|
||||
stream: None,
|
||||
stop_sequences: None,
|
||||
tools: None,
|
||||
tool_choice: None,
|
||||
};
|
||||
|
||||
let provider_req = ProviderRequestType::MessagesRequest(anthropic_req);
|
||||
let messages = provider_req.get_message_history();
|
||||
|
||||
// Should have system message + user message
|
||||
assert_eq!(messages.len(), 2);
|
||||
assert_eq!(
|
||||
messages[0].role,
|
||||
crate::apis::openai::Role::System
|
||||
);
|
||||
assert_eq!(
|
||||
messages[1].role,
|
||||
crate::apis::openai::Role::User
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_message_history_responses_api() {
|
||||
use crate::apis::openai_responses::{InputParam, ResponsesAPIRequest};
|
||||
|
||||
let responses_req = ResponsesAPIRequest {
|
||||
model: "gpt-4o".to_string(),
|
||||
input: InputParam::Text("Hello, world!".to_string()),
|
||||
instructions: Some("Be helpful".to_string()),
|
||||
temperature: None,
|
||||
max_output_tokens: None,
|
||||
stream: None,
|
||||
metadata: None,
|
||||
tools: None,
|
||||
tool_choice: None,
|
||||
parallel_tool_calls: None,
|
||||
modalities: None,
|
||||
user: None,
|
||||
store: None,
|
||||
reasoning_effort: None,
|
||||
include: None,
|
||||
audio: None,
|
||||
text: None,
|
||||
service_tier: None,
|
||||
top_p: None,
|
||||
top_logprobs: None,
|
||||
stream_options: None,
|
||||
truncation: None,
|
||||
conversation: None,
|
||||
previous_response_id: None,
|
||||
max_tool_calls: None,
|
||||
background: None,
|
||||
};
|
||||
|
||||
let provider_req = ProviderRequestType::ResponsesAPIRequest(responses_req);
|
||||
let messages = provider_req.get_message_history();
|
||||
|
||||
// Should have system message (instructions) + user message (input)
|
||||
assert_eq!(messages.len(), 2);
|
||||
assert_eq!(
|
||||
messages[0].role,
|
||||
crate::apis::openai::Role::System
|
||||
);
|
||||
assert_eq!(
|
||||
messages[1].role,
|
||||
crate::apis::openai::Role::User
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue