mirror of
https://github.com/katanemo/plano.git
synced 2026-06-26 15:39:40 +02:00
fixing test cases, and making sure all references to the ChatCOmpletions* objects point to the new types
This commit is contained in:
parent
df32c7e278
commit
7253a0f203
15 changed files with 224 additions and 838 deletions
|
|
@ -3,7 +3,7 @@ use std::sync::Arc;
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use common::configuration::ModelUsagePreference;
|
use common::configuration::ModelUsagePreference;
|
||||||
use common::consts::ARCH_PROVIDER_HINT_HEADER;
|
use common::consts::ARCH_PROVIDER_HINT_HEADER;
|
||||||
use hermesllm::providers::openai::types::ChatCompletionsRequest;
|
use hermesllm::apis::openai::ChatCompletionsRequest;
|
||||||
use http_body_util::combinators::BoxBody;
|
use http_body_util::combinators::BoxBody;
|
||||||
use http_body_util::{BodyExt, Full, StreamBody};
|
use http_body_util::{BodyExt, Full, StreamBody};
|
||||||
use hyper::body::Frame;
|
use hyper::body::Frame;
|
||||||
|
|
@ -93,7 +93,7 @@ pub async fn chat_completions(
|
||||||
chat_completion_request.metadata.and_then(|metadata| {
|
chat_completion_request.metadata.and_then(|metadata| {
|
||||||
metadata
|
metadata
|
||||||
.get("archgw_preference_config")
|
.get("archgw_preference_config")
|
||||||
.and_then(|value| value.as_str().map(String::from))
|
.map(|value| value.to_string())
|
||||||
});
|
});
|
||||||
|
|
||||||
let usage_preferences: Option<Vec<ModelUsagePreference>> = usage_preferences_str
|
let usage_preferences: Option<Vec<ModelUsagePreference>> = usage_preferences_str
|
||||||
|
|
@ -105,9 +105,7 @@ pub async fn chat_completions(
|
||||||
.messages
|
.messages
|
||||||
.last()
|
.last()
|
||||||
.map_or("None".to_string(), |msg| {
|
.map_or("None".to_string(), |msg| {
|
||||||
msg.content.as_ref().map_or("None".to_string(), |content| {
|
msg.content.to_string().replace('\n', "\\n")
|
||||||
content.to_string().replace('\n', "\\n")
|
|
||||||
})
|
|
||||||
});
|
});
|
||||||
|
|
||||||
const MAX_MESSAGE_LENGTH: usize = 50;
|
const MAX_MESSAGE_LENGTH: usize = 50;
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use common::configuration::{IntoModels, LlmProvider};
|
use common::configuration::{IntoModels, LlmProvider};
|
||||||
use hermesllm::providers::openai::types::Models;
|
use hermesllm::apis::openai::Models;
|
||||||
use http_body_util::{combinators::BoxBody, BodyExt, Full};
|
use http_body_util::{combinators::BoxBody, BodyExt, Full};
|
||||||
use hyper::{Response, StatusCode};
|
use hyper::{Response, StatusCode};
|
||||||
use serde_json;
|
use serde_json;
|
||||||
|
|
|
||||||
|
|
@ -98,7 +98,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let peer_addr = stream.peer_addr()?;
|
let peer_addr = stream.peer_addr()?;
|
||||||
let io = TokioIo::new(stream);
|
let io = TokioIo::new(stream);
|
||||||
|
|
||||||
let router_service = Arc::clone(&router_service);
|
let router_service: Arc<RouterService> = Arc::clone(&router_service);
|
||||||
let llm_provider_endpoint = llm_provider_endpoint.clone();
|
let llm_provider_endpoint = llm_provider_endpoint.clone();
|
||||||
|
|
||||||
let llm_providers = llm_providers.clone();
|
let llm_providers = llm_providers.clone();
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ use common::{
|
||||||
configuration::{LlmProvider, ModelUsagePreference, RoutingPreference},
|
configuration::{LlmProvider, ModelUsagePreference, RoutingPreference},
|
||||||
consts::ARCH_PROVIDER_HINT_HEADER,
|
consts::ARCH_PROVIDER_HINT_HEADER,
|
||||||
};
|
};
|
||||||
use hermesllm::providers::openai::types::{ChatCompletionsResponse, ContentType, Message};
|
use hermesllm::apis::openai::{ChatCompletionsResponse, Message};
|
||||||
use hyper::header;
|
use hyper::header;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tracing::{debug, info, warn};
|
use tracing::{debug, info, warn};
|
||||||
|
|
@ -153,9 +153,7 @@ impl RouterService {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(ContentType::Text(content)) =
|
if let Some(content) = &chat_completion_response.choices[0].message.content {
|
||||||
&chat_completion_response.choices[0].message.content
|
|
||||||
{
|
|
||||||
let parsed_response = self
|
let parsed_response = self
|
||||||
.router_model
|
.router_model
|
||||||
.parse_response(content, &usage_preferences)?;
|
.parse_response(content, &usage_preferences)?;
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
use common::configuration::ModelUsagePreference;
|
use common::configuration::ModelUsagePreference;
|
||||||
use hermesllm::providers::openai::types::{ChatCompletionsRequest, Message};
|
use hermesllm::apis::openai::{ChatCompletionsRequest, Message};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,8 @@ use std::collections::HashMap;
|
||||||
|
|
||||||
use common::{
|
use common::{
|
||||||
configuration::{ModelUsagePreference, RoutingPreference},
|
configuration::{ModelUsagePreference, RoutingPreference},
|
||||||
consts::{SYSTEM_ROLE, TOOL_ROLE, USER_ROLE},
|
|
||||||
};
|
};
|
||||||
use hermesllm::providers::openai::types::{ChatCompletionsRequest, ContentType, Message};
|
use hermesllm::apis::openai::{ChatCompletionsRequest, MessageContent, Message, Role};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use tracing::{debug, warn};
|
use tracing::{debug, warn};
|
||||||
|
|
||||||
|
|
@ -80,7 +79,9 @@ impl RouterModel for RouterModelV1 {
|
||||||
// when role == tool its tool call response
|
// when role == tool its tool call response
|
||||||
let messages_vec = messages
|
let messages_vec = messages
|
||||||
.iter()
|
.iter()
|
||||||
.filter(|m| m.role != SYSTEM_ROLE && m.role != TOOL_ROLE && m.content.is_some())
|
.filter(|m| {
|
||||||
|
m.role != Role::System && m.role != Role::Tool && !m.content.to_string().is_empty()
|
||||||
|
})
|
||||||
.collect::<Vec<&Message>>();
|
.collect::<Vec<&Message>>();
|
||||||
|
|
||||||
// Following code is to ensure that the conversation does not exceed max token length
|
// Following code is to ensure that the conversation does not exceed max token length
|
||||||
|
|
@ -88,13 +89,7 @@ impl RouterModel for RouterModelV1 {
|
||||||
let mut token_count = ARCH_ROUTER_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR;
|
let mut token_count = ARCH_ROUTER_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR;
|
||||||
let mut selected_messages_list_reversed: Vec<&Message> = vec![];
|
let mut selected_messages_list_reversed: Vec<&Message> = vec![];
|
||||||
for (selected_messsage_count, message) in messages_vec.iter().rev().enumerate() {
|
for (selected_messsage_count, message) in messages_vec.iter().rev().enumerate() {
|
||||||
let message_token_count = message
|
let message_token_count = message.content.to_string().len() / TOKEN_LENGTH_DIVISOR;
|
||||||
.content
|
|
||||||
.as_ref()
|
|
||||||
.unwrap_or(&ContentType::Text("".to_string()))
|
|
||||||
.to_string()
|
|
||||||
.len()
|
|
||||||
/ TOKEN_LENGTH_DIVISOR;
|
|
||||||
token_count += message_token_count;
|
token_count += message_token_count;
|
||||||
if token_count > self.max_token_length {
|
if token_count > self.max_token_length {
|
||||||
debug!(
|
debug!(
|
||||||
|
|
@ -104,7 +99,7 @@ impl RouterModel for RouterModelV1 {
|
||||||
, selected_messsage_count,
|
, selected_messsage_count,
|
||||||
messages_vec.len()
|
messages_vec.len()
|
||||||
);
|
);
|
||||||
if message.role == USER_ROLE {
|
if message.role == Role::User {
|
||||||
// If message that exceeds max token length is from user, we need to keep it
|
// If message that exceeds max token length is from user, we need to keep it
|
||||||
selected_messages_list_reversed.push(message);
|
selected_messages_list_reversed.push(message);
|
||||||
}
|
}
|
||||||
|
|
@ -125,12 +120,12 @@ impl RouterModel for RouterModelV1 {
|
||||||
|
|
||||||
// ensure that first and last selected message is from user
|
// ensure that first and last selected message is from user
|
||||||
if let Some(first_message) = selected_messages_list_reversed.first() {
|
if let Some(first_message) = selected_messages_list_reversed.first() {
|
||||||
if first_message.role != USER_ROLE {
|
if first_message.role != Role::User {
|
||||||
warn!("RouterModelV1: last message in the conversation is not from user, this may lead to incorrect routing");
|
warn!("RouterModelV1: last message in the conversation is not from user, this may lead to incorrect routing");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if let Some(last_message) = selected_messages_list_reversed.last() {
|
if let Some(last_message) = selected_messages_list_reversed.last() {
|
||||||
if last_message.role != USER_ROLE {
|
if last_message.role != Role::User {
|
||||||
warn!("RouterModelV1: first message in the conversation is not from user, this may lead to incorrect routing");
|
warn!("RouterModelV1: first message in the conversation is not from user, this may lead to incorrect routing");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -143,9 +138,10 @@ impl RouterModel for RouterModelV1 {
|
||||||
Message {
|
Message {
|
||||||
role: message.role.clone(),
|
role: message.role.clone(),
|
||||||
// we can unwrap here because we have already filtered out messages without content
|
// we can unwrap here because we have already filtered out messages without content
|
||||||
content: Some(ContentType::Text(
|
content: MessageContent::Text(message.content.to_string()),
|
||||||
message.content.as_ref().unwrap().to_string(),
|
name: None,
|
||||||
)),
|
tool_calls: None,
|
||||||
|
tool_call_id: None,
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.collect::<Vec<Message>>();
|
.collect::<Vec<Message>>();
|
||||||
|
|
@ -160,8 +156,11 @@ impl RouterModel for RouterModelV1 {
|
||||||
ChatCompletionsRequest {
|
ChatCompletionsRequest {
|
||||||
model: self.routing_model.clone(),
|
model: self.routing_model.clone(),
|
||||||
messages: vec![Message {
|
messages: vec![Message {
|
||||||
content: Some(ContentType::Text(router_message)),
|
content: MessageContent::Text(router_message),
|
||||||
role: USER_ROLE.to_string(),
|
role: Role::User,
|
||||||
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
|
tool_call_id: None,
|
||||||
}],
|
}],
|
||||||
temperature: Some(0.01),
|
temperature: Some(0.01),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
|
|
@ -347,9 +346,9 @@ Based on your analysis, provide your response in the following JSON formats if y
|
||||||
|
|
||||||
let req = router.generate_request(&conversation, &None);
|
let req = router.generate_request(&conversation, &None);
|
||||||
|
|
||||||
let prompt = req.messages[0].content.as_ref().unwrap();
|
let prompt = req.messages[0].content.to_string();
|
||||||
|
|
||||||
assert_eq!(expected_prompt, prompt.to_string());
|
assert_eq!(expected_prompt, prompt);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -412,9 +411,9 @@ Based on your analysis, provide your response in the following JSON formats if y
|
||||||
}]);
|
}]);
|
||||||
let req = router.generate_request(&conversation, &usage_preferences);
|
let req = router.generate_request(&conversation, &usage_preferences);
|
||||||
|
|
||||||
let prompt = req.messages[0].content.as_ref().unwrap();
|
let prompt = req.messages[0].content.to_string();
|
||||||
|
|
||||||
assert_eq!(expected_prompt, prompt.to_string());
|
assert_eq!(expected_prompt, prompt);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -472,9 +471,9 @@ Based on your analysis, provide your response in the following JSON formats if y
|
||||||
|
|
||||||
let req = router.generate_request(&conversation, &None);
|
let req = router.generate_request(&conversation, &None);
|
||||||
|
|
||||||
let prompt = req.messages[0].content.as_ref().unwrap();
|
let prompt = req.messages[0].content.to_string();
|
||||||
|
|
||||||
assert_eq!(expected_prompt, prompt.to_string());
|
assert_eq!(expected_prompt, prompt);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -533,9 +532,9 @@ Based on your analysis, provide your response in the following JSON formats if y
|
||||||
|
|
||||||
let req = router.generate_request(&conversation, &None);
|
let req = router.generate_request(&conversation, &None);
|
||||||
|
|
||||||
let prompt = req.messages[0].content.as_ref().unwrap();
|
let prompt = req.messages[0].content.to_string();
|
||||||
|
|
||||||
assert_eq!(expected_prompt, prompt.to_string());
|
assert_eq!(expected_prompt, prompt);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -601,9 +600,9 @@ Based on your analysis, provide your response in the following JSON formats if y
|
||||||
|
|
||||||
let req = router.generate_request(&conversation, &None);
|
let req = router.generate_request(&conversation, &None);
|
||||||
|
|
||||||
let prompt = req.messages[0].content.as_ref().unwrap();
|
let prompt = req.messages[0].content.to_string();
|
||||||
|
|
||||||
assert_eq!(expected_prompt, prompt.to_string());
|
assert_eq!(expected_prompt, prompt);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -670,9 +669,9 @@ Based on your analysis, provide your response in the following JSON formats if y
|
||||||
|
|
||||||
let req = router.generate_request(&conversation, &None);
|
let req = router.generate_request(&conversation, &None);
|
||||||
|
|
||||||
let prompt = req.messages[0].content.as_ref().unwrap();
|
let prompt = req.messages[0].content.to_string();
|
||||||
|
|
||||||
assert_eq!(expected_prompt, prompt.to_string());
|
assert_eq!(expected_prompt, prompt);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -716,14 +715,14 @@ Based on your analysis, provide your response in the following JSON formats if y
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": null,
|
"content": "",
|
||||||
"tool_calls": [
|
"tool_calls": [
|
||||||
{
|
{
|
||||||
"id": "toolcall-abc123",
|
"id": "toolcall-abc123",
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": "get_weather",
|
"name": "get_weather",
|
||||||
"arguments": { "location": "Tokyo" }
|
"arguments": "{ \"location\": \"Tokyo\" }"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
@ -763,11 +762,11 @@ Based on your analysis, provide your response in the following JSON formats if y
|
||||||
|
|
||||||
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
|
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
|
||||||
|
|
||||||
let req = router.generate_request(&conversation, &None);
|
let req: ChatCompletionsRequest = router.generate_request(&conversation, &None);
|
||||||
|
|
||||||
let prompt = req.messages[0].content.as_ref().unwrap();
|
let prompt = req.messages[0].content.to_string();
|
||||||
|
|
||||||
assert_eq!(expected_prompt, prompt.to_string());
|
assert_eq!(expected_prompt, prompt);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
use hermesllm::providers::openai::types::{ModelDetail, ModelObject, Models};
|
use hermesllm::apis::openai::{ModelDetail, ModelObject, Models};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::fmt::Display;
|
use std::fmt::Display;
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
use proxy_wasm::types::Status;
|
use proxy_wasm::types::Status;
|
||||||
|
|
||||||
use crate::{api::open_ai::ChatCompletionChunkResponseError, ratelimit};
|
use crate::{api::open_ai::ChatCompletionChunkResponseError, ratelimit};
|
||||||
use hermesllm::providers::openai::types::OpenAIError;
|
use hermesllm::apis::openai::OpenAIError;
|
||||||
|
|
||||||
#[derive(thiserror::Error, Debug)]
|
#[derive(thiserror::Error, Debug)]
|
||||||
pub enum ClientError {
|
pub enum ClientError {
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,8 @@ use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use serde_with::skip_serializing_none;
|
use serde_with::skip_serializing_none;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use std::fmt::Display;
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
use crate::{providers::ProviderRequestError, ConversionMode, ProviderRequest};
|
use crate::{providers::ProviderRequestError, ConversionMode, ProviderRequest};
|
||||||
use super::ApiDefinition;
|
use super::ApiDefinition;
|
||||||
|
|
@ -116,8 +118,8 @@ pub enum Role {
|
||||||
#[skip_serializing_none]
|
#[skip_serializing_none]
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||||
pub struct Message {
|
pub struct Message {
|
||||||
pub content: MessageContent,
|
|
||||||
pub role: Role,
|
pub role: Role,
|
||||||
|
pub content: MessageContent,
|
||||||
pub name: Option<String>,
|
pub name: Option<String>,
|
||||||
/// Tool calls made by the assistant (only present for assistant role)
|
/// Tool calls made by the assistant (only present for assistant role)
|
||||||
pub tool_calls: Option<Vec<ToolCall>>,
|
pub tool_calls: Option<Vec<ToolCall>>,
|
||||||
|
|
@ -171,6 +173,28 @@ pub enum MessageContent {
|
||||||
Parts(Vec<ContentPart>),
|
Parts(Vec<ContentPart>),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Display for MessageContent {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
MessageContent::Text(text) => write!(f, "{}", text),
|
||||||
|
MessageContent::Parts(parts) => {
|
||||||
|
let text_parts: Vec<String> = parts
|
||||||
|
.iter()
|
||||||
|
.filter_map(|part| match part {
|
||||||
|
ContentPart::Text { text } => Some(text.clone()),
|
||||||
|
ContentPart::ImageUrl { .. } => {
|
||||||
|
// skip image URLs or their data in text representation
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let combined_text = text_parts.join("\n");
|
||||||
|
write!(f, "{}", combined_text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Individual content part within a message (text or image)
|
/// Individual content part within a message (text or image)
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||||
#[serde(tag = "type")]
|
#[serde(tag = "type")]
|
||||||
|
|
@ -560,6 +584,26 @@ impl TokenUsage for Usage {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ModelDetail {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String,
|
||||||
|
pub created: usize,
|
||||||
|
pub owned_by: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub enum ModelObject {
|
||||||
|
#[serde(rename = "list")]
|
||||||
|
List,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct Models {
|
||||||
|
pub object: ModelObject,
|
||||||
|
pub data: Vec<ModelDetail>,
|
||||||
|
}
|
||||||
|
|
||||||
// Error type for streaming operations
|
// Error type for streaming operations
|
||||||
#[derive(Debug, thiserror::Error)]
|
#[derive(Debug, thiserror::Error)]
|
||||||
pub enum OpenAIStreamError {
|
pub enum OpenAIStreamError {
|
||||||
|
|
@ -571,6 +615,22 @@ pub enum OpenAIStreamError {
|
||||||
InvalidStreamingData(String),
|
InvalidStreamingData(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum OpenAIError {
|
||||||
|
#[error("json error: {0}")]
|
||||||
|
JsonParseError(#[from] serde_json::Error),
|
||||||
|
#[error("utf8 parsing error: {0}")]
|
||||||
|
Utf8Error(#[from] std::str::Utf8Error),
|
||||||
|
#[error("invalid streaming data err {source}, data: {data}")]
|
||||||
|
InvalidStreamingData {
|
||||||
|
source: serde_json::Error,
|
||||||
|
data: String,
|
||||||
|
},
|
||||||
|
#[error("unsupported provider: {provider}")]
|
||||||
|
UnsupportedProvider { provider: String },
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/// SSE-based streaming iterator for OpenAI chat completions
|
/// SSE-based streaming iterator for OpenAI chat completions
|
||||||
/// Implements ProviderStreamResponseIter directly
|
/// Implements ProviderStreamResponseIter directly
|
||||||
pub struct SseChatCompletionIter<I>
|
pub struct SseChatCompletionIter<I>
|
||||||
|
|
|
||||||
|
|
@ -1,114 +0,0 @@
|
||||||
use serde_json::Value;
|
|
||||||
|
|
||||||
use crate::providers::openai::types::{ChatCompletionsRequest, Message, StreamOptions};
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct OpenAIRequestBuilder {
|
|
||||||
model: String,
|
|
||||||
messages: Vec<Message>,
|
|
||||||
temperature: Option<f32>,
|
|
||||||
top_p: Option<f32>,
|
|
||||||
n: Option<u32>,
|
|
||||||
max_tokens: Option<u32>,
|
|
||||||
stream: Option<bool>,
|
|
||||||
stop: Option<Vec<String>>,
|
|
||||||
presence_penalty: Option<f32>,
|
|
||||||
frequency_penalty: Option<f32>,
|
|
||||||
stream_options: Option<StreamOptions>,
|
|
||||||
tools: Option<Vec<Value>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl OpenAIRequestBuilder {
|
|
||||||
pub fn new(model: impl Into<String>, messages: Vec<Message>) -> Self {
|
|
||||||
Self {
|
|
||||||
model: model.into(),
|
|
||||||
messages,
|
|
||||||
temperature: None,
|
|
||||||
top_p: None,
|
|
||||||
n: None,
|
|
||||||
max_tokens: None,
|
|
||||||
stream: None,
|
|
||||||
stop: None,
|
|
||||||
presence_penalty: None,
|
|
||||||
frequency_penalty: None,
|
|
||||||
stream_options: None,
|
|
||||||
tools: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn temperature(mut self, temperature: f32) -> Self {
|
|
||||||
self.temperature = Some(temperature);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn top_p(mut self, top_p: f32) -> Self {
|
|
||||||
self.top_p = Some(top_p);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn n(mut self, n: u32) -> Self {
|
|
||||||
self.n = Some(n);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn max_tokens(mut self, max_tokens: u32) -> Self {
|
|
||||||
self.max_tokens = Some(max_tokens);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn stream(mut self, stream: bool) -> Self {
|
|
||||||
self.stream = Some(stream);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn stop(mut self, stop: Vec<String>) -> Self {
|
|
||||||
self.stop = Some(stop);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn presence_penalty(mut self, presence_penalty: f32) -> Self {
|
|
||||||
self.presence_penalty = Some(presence_penalty);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn frequency_penalty(mut self, frequency_penalty: f32) -> Self {
|
|
||||||
self.frequency_penalty = Some(frequency_penalty);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn stream_options(mut self, include_usage: bool) -> Self {
|
|
||||||
self.stream = Some(true);
|
|
||||||
self.stream_options = Some(StreamOptions { include_usage });
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn tools(mut self, tools: Vec<Value>) -> Self {
|
|
||||||
self.tools = Some(tools);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn build(self) -> Result<ChatCompletionsRequest, &'static str> {
|
|
||||||
let request = ChatCompletionsRequest {
|
|
||||||
model: self.model,
|
|
||||||
messages: self.messages,
|
|
||||||
temperature: self.temperature,
|
|
||||||
top_p: self.top_p,
|
|
||||||
n: self.n,
|
|
||||||
max_tokens: self.max_tokens,
|
|
||||||
stream: self.stream,
|
|
||||||
stop: self.stop,
|
|
||||||
presence_penalty: self.presence_penalty,
|
|
||||||
frequency_penalty: self.frequency_penalty,
|
|
||||||
stream_options: self.stream_options,
|
|
||||||
tools: self.tools,
|
|
||||||
metadata: None,
|
|
||||||
};
|
|
||||||
Ok(request)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ChatCompletionsRequest {
|
|
||||||
pub fn builder(model: impl Into<String>, messages: Vec<Message>) -> OpenAIRequestBuilder {
|
|
||||||
OpenAIRequestBuilder::new(model, messages)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,10 +1,5 @@
|
||||||
pub mod builder;
|
|
||||||
pub mod types;
|
|
||||||
|
|
||||||
// Re-export the main types and builder functionality
|
// Re-export the main types and builder functionality
|
||||||
pub use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, ChatCompletionsStreamResponse};
|
pub use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, ChatCompletionsStreamResponse};
|
||||||
pub use builder::*;
|
|
||||||
pub use types::*;
|
|
||||||
|
|
||||||
// Note: The OpenAIProvider struct has been deprecated in favor of the function-based approach in traits.rs
|
// Note: The OpenAIProvider struct has been deprecated in favor of the function-based approach in traits.rs
|
||||||
// All provider functionality is now accessed through try_request_from_bytes, try_response_from_bytes, etc.
|
// All provider functionality is now accessed through try_request_from_bytes, try_response_from_bytes, etc.
|
||||||
|
|
|
||||||
|
|
@ -1,552 +0,0 @@
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::fmt::Display;
|
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use serde_json::Value;
|
|
||||||
use serde_with::skip_serializing_none;
|
|
||||||
use std::convert::TryFrom;
|
|
||||||
use std::str;
|
|
||||||
use thiserror::Error;
|
|
||||||
|
|
||||||
use crate::providers::ProviderId;
|
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
|
||||||
pub enum OpenAIError {
|
|
||||||
#[error("json error: {0}")]
|
|
||||||
JsonParseError(#[from] serde_json::Error),
|
|
||||||
#[error("utf8 parsing error: {0}")]
|
|
||||||
Utf8Error(#[from] std::str::Utf8Error),
|
|
||||||
#[error("invalid streaming data err {source}, data: {data}")]
|
|
||||||
InvalidStreamingData {
|
|
||||||
source: serde_json::Error,
|
|
||||||
data: String,
|
|
||||||
},
|
|
||||||
#[error("unsupported provider: {provider}")]
|
|
||||||
UnsupportedProvider { provider: String },
|
|
||||||
}
|
|
||||||
|
|
||||||
type Result<T> = std::result::Result<T, OpenAIError>;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
|
||||||
pub enum MultiPartContentType {
|
|
||||||
#[serde(rename = "text")]
|
|
||||||
Text,
|
|
||||||
#[serde(rename = "image_url")]
|
|
||||||
ImageUrl,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
|
||||||
pub struct ImageUrl {
|
|
||||||
pub url: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[skip_serializing_none]
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
|
||||||
pub struct MultiPartContent {
|
|
||||||
pub text: Option<String>,
|
|
||||||
pub image_url: Option<ImageUrl>,
|
|
||||||
#[serde(rename = "type")]
|
|
||||||
pub content_type: MultiPartContentType,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
pub enum ContentType {
|
|
||||||
Text(String),
|
|
||||||
MultiPart(Vec<MultiPartContent>),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Display for ContentType {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
match self {
|
|
||||||
ContentType::Text(text) => write!(f, "{}", text),
|
|
||||||
ContentType::MultiPart(multi_part) => {
|
|
||||||
let text_parts: Vec<String> = multi_part
|
|
||||||
.iter()
|
|
||||||
.filter_map(|part| {
|
|
||||||
if part.content_type == MultiPartContentType::Text {
|
|
||||||
part.text.clone()
|
|
||||||
} else if part.content_type == MultiPartContentType::ImageUrl {
|
|
||||||
// skip image URLs or their data in text representation
|
|
||||||
None
|
|
||||||
} else {
|
|
||||||
panic!("Unsupported content type: {:?}", part.content_type);
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
let combined_text = text_parts.join("\n");
|
|
||||||
write!(f, "{}", combined_text)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[skip_serializing_none]
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct Message {
|
|
||||||
pub role: String,
|
|
||||||
pub content: Option<ContentType>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Message {
|
|
||||||
pub fn new(content: String) -> Self {
|
|
||||||
Self {
|
|
||||||
role: "user".to_string(),
|
|
||||||
content: Some(ContentType::Text(content)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct StreamOptions {
|
|
||||||
pub include_usage: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[skip_serializing_none]
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
|
||||||
pub struct ChatCompletionsRequest {
|
|
||||||
pub model: String,
|
|
||||||
pub messages: Vec<Message>,
|
|
||||||
pub temperature: Option<f32>,
|
|
||||||
pub top_p: Option<f32>,
|
|
||||||
pub n: Option<u32>,
|
|
||||||
pub max_tokens: Option<u32>,
|
|
||||||
pub stream: Option<bool>,
|
|
||||||
pub stop: Option<Vec<String>>,
|
|
||||||
pub presence_penalty: Option<f32>,
|
|
||||||
pub frequency_penalty: Option<f32>,
|
|
||||||
pub stream_options: Option<StreamOptions>,
|
|
||||||
pub tools: Option<Vec<Value>>,
|
|
||||||
pub metadata: Option<HashMap<String, Value>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TryFrom<&[u8]> for ChatCompletionsRequest {
|
|
||||||
type Error = OpenAIError;
|
|
||||||
fn try_from(bytes: &[u8]) -> Result<Self> {
|
|
||||||
serde_json::from_slice(bytes).map_err(OpenAIError::from)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[skip_serializing_none]
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct ChatCompletionsResponse {
|
|
||||||
pub id: String,
|
|
||||||
pub object: String,
|
|
||||||
pub created: u64,
|
|
||||||
pub choices: Vec<Choice>,
|
|
||||||
pub usage: Option<Usage>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TryFrom<&[u8]> for ChatCompletionsResponse {
|
|
||||||
type Error = OpenAIError;
|
|
||||||
fn try_from(bytes: &[u8]) -> Result<Self> {
|
|
||||||
serde_json::from_slice(bytes).map_err(OpenAIError::from)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a> TryFrom<(&'a [u8], &'a ProviderId)> for ChatCompletionsResponse {
|
|
||||||
type Error = OpenAIError;
|
|
||||||
|
|
||||||
fn try_from(input: (&'a [u8], &'a ProviderId)) -> Result<Self> {
|
|
||||||
// Use input.provider as needed, if necessary
|
|
||||||
serde_json::from_slice(input.0).map_err(OpenAIError::from)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ChatCompletionsRequest {
|
|
||||||
pub fn to_bytes(&self, provider: ProviderId) -> Result<Vec<u8>> {
|
|
||||||
match provider {
|
|
||||||
ProviderId::OpenAI
|
|
||||||
| ProviderId::Arch
|
|
||||||
| ProviderId::Deepseek
|
|
||||||
| ProviderId::Mistral
|
|
||||||
| ProviderId::Groq
|
|
||||||
| ProviderId::Gemini
|
|
||||||
| ProviderId::Claude
|
|
||||||
| ProviderId::GitHub => serde_json::to_vec(self).map_err(OpenAIError::from),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[skip_serializing_none]
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct Choice {
|
|
||||||
pub index: u32,
|
|
||||||
pub message: Message,
|
|
||||||
pub finish_reason: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[skip_serializing_none]
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct Usage {
|
|
||||||
pub prompt_tokens: usize,
|
|
||||||
pub completion_tokens: usize,
|
|
||||||
pub total_tokens: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[skip_serializing_none]
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct DeltaMessage {
|
|
||||||
pub role: Option<String>,
|
|
||||||
pub content: Option<ContentType>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct StreamChoice {
|
|
||||||
pub index: u32,
|
|
||||||
pub delta: DeltaMessage,
|
|
||||||
pub finish_reason: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct ChatCompletionStreamResponse {
|
|
||||||
pub id: String,
|
|
||||||
pub object: String,
|
|
||||||
pub created: u64,
|
|
||||||
pub model: String,
|
|
||||||
pub choices: Vec<StreamChoice>,
|
|
||||||
pub usage: Option<Usage>,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct SseChatCompletionIter<I>
|
|
||||||
where
|
|
||||||
I: Iterator,
|
|
||||||
I::Item: AsRef<str>,
|
|
||||||
{
|
|
||||||
lines: I,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<I> SseChatCompletionIter<I>
|
|
||||||
where
|
|
||||||
I: Iterator,
|
|
||||||
I::Item: AsRef<str>,
|
|
||||||
{
|
|
||||||
pub fn new(lines: I) -> Self {
|
|
||||||
Self { lines }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<I> Iterator for SseChatCompletionIter<I>
|
|
||||||
where
|
|
||||||
I: Iterator,
|
|
||||||
I::Item: AsRef<str>,
|
|
||||||
{
|
|
||||||
type Item = Result<ChatCompletionStreamResponse>;
|
|
||||||
|
|
||||||
fn next(&mut self) -> Option<Self::Item> {
|
|
||||||
for line in &mut self.lines {
|
|
||||||
let line = line.as_ref();
|
|
||||||
if let Some(data) = line.strip_prefix("data: ") {
|
|
||||||
let data = data.trim();
|
|
||||||
if data == "[DONE]" {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
if data == r#"{"type": "ping"}"# {
|
|
||||||
continue; // Skip ping messages - that is usually from anthropic
|
|
||||||
}
|
|
||||||
|
|
||||||
return Some(
|
|
||||||
serde_json::from_str::<ChatCompletionStreamResponse>(data).map_err(|e| {
|
|
||||||
OpenAIError::InvalidStreamingData {
|
|
||||||
source: e,
|
|
||||||
data: data.to_string(),
|
|
||||||
}
|
|
||||||
}),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
impl<'a> TryFrom<&'a [u8]> for SseChatCompletionIter<str::Lines<'a>> {
|
|
||||||
type Error = OpenAIError;
|
|
||||||
|
|
||||||
fn try_from(bytes: &'a [u8]) -> Result<Self> {
|
|
||||||
let s = std::str::from_utf8(bytes)?;
|
|
||||||
Ok(SseChatCompletionIter::new(s.lines()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct ModelDetail {
|
|
||||||
pub id: String,
|
|
||||||
pub object: String,
|
|
||||||
pub created: usize,
|
|
||||||
pub owned_by: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub enum ModelObject {
|
|
||||||
#[serde(rename = "list")]
|
|
||||||
List,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct Models {
|
|
||||||
pub object: ModelObject,
|
|
||||||
pub data: Vec<ModelDetail>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_content_type_display() {
|
|
||||||
let text_content = ContentType::Text("Hello, world!".to_string());
|
|
||||||
assert_eq!(text_content.to_string(), "Hello, world!");
|
|
||||||
|
|
||||||
let multi_part_content = ContentType::MultiPart(vec![
|
|
||||||
MultiPartContent {
|
|
||||||
text: Some("This is a text part.".to_string()),
|
|
||||||
content_type: MultiPartContentType::Text,
|
|
||||||
image_url: None,
|
|
||||||
},
|
|
||||||
MultiPartContent {
|
|
||||||
text: Some("https://example.com/image.png".to_string()),
|
|
||||||
content_type: MultiPartContentType::ImageUrl,
|
|
||||||
image_url: None,
|
|
||||||
},
|
|
||||||
]);
|
|
||||||
assert_eq!(multi_part_content.to_string(), "This is a text part.");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_chat_completions_request_text_type_array() {
|
|
||||||
const CHAT_COMPLETIONS_REQUEST: &str = r#"
|
|
||||||
{
|
|
||||||
"model": "gpt-3.5-turbo",
|
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": "What city do you want to know the weather for?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": "hello world"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
"#;
|
|
||||||
|
|
||||||
let chat_completions_request: ChatCompletionsRequest =
|
|
||||||
serde_json::from_str(CHAT_COMPLETIONS_REQUEST).unwrap();
|
|
||||||
assert_eq!(chat_completions_request.model, "gpt-3.5-turbo");
|
|
||||||
if let Some(ContentType::MultiPart(multi_part_content)) =
|
|
||||||
chat_completions_request.messages[0].content.as_ref()
|
|
||||||
{
|
|
||||||
assert_eq!(multi_part_content.len(), 2);
|
|
||||||
assert_eq!(
|
|
||||||
multi_part_content[0].content_type,
|
|
||||||
MultiPartContentType::Text
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
multi_part_content[0].text,
|
|
||||||
Some("What city do you want to know the weather for?".to_string())
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
multi_part_content[1].content_type,
|
|
||||||
MultiPartContentType::Text
|
|
||||||
);
|
|
||||||
assert_eq!(multi_part_content[1].text, Some("hello world".to_string()));
|
|
||||||
} else {
|
|
||||||
panic!("Expected MultiPartContent");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_chat_completions_request_image_content() {
|
|
||||||
const CHAT_COMPLETIONS_REQUEST: &str = r#"
|
|
||||||
{
|
|
||||||
"stream": true,
|
|
||||||
"model": "openai/gpt-4o",
|
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": "describe this photo pls"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {
|
|
||||||
"url": "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/...=="
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}"#;
|
|
||||||
|
|
||||||
let chat_completions_request: ChatCompletionsRequest =
|
|
||||||
serde_json::from_str(CHAT_COMPLETIONS_REQUEST).unwrap();
|
|
||||||
assert_eq!(chat_completions_request.model, "openai/gpt-4o");
|
|
||||||
if let Some(ContentType::MultiPart(multi_part_content)) =
|
|
||||||
chat_completions_request.messages[0].content.as_ref()
|
|
||||||
{
|
|
||||||
assert_eq!(multi_part_content.len(), 2);
|
|
||||||
assert_eq!(
|
|
||||||
multi_part_content[0].content_type,
|
|
||||||
MultiPartContentType::Text
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
multi_part_content[0].text,
|
|
||||||
Some("describe this photo pls".to_string())
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
multi_part_content[1].content_type,
|
|
||||||
MultiPartContentType::ImageUrl
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
multi_part_content[1].image_url,
|
|
||||||
Some(ImageUrl {
|
|
||||||
url: "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/...==".to_string(),
|
|
||||||
})
|
|
||||||
);
|
|
||||||
} else {
|
|
||||||
panic!("Expected MultiPartContent");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_sse_streaming() {
|
|
||||||
let json_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}
|
|
||||||
data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"Hello, how can I help you today?"},"finish_reason":null}]}
|
|
||||||
data: [DONE]"#;
|
|
||||||
|
|
||||||
let iter = SseChatCompletionIter::new(json_data.lines());
|
|
||||||
|
|
||||||
println!("Testing SSE Streaming");
|
|
||||||
for item in iter {
|
|
||||||
match item {
|
|
||||||
Ok(response) => {
|
|
||||||
println!("Received response: {:?}", response);
|
|
||||||
if response.choices.is_empty() {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
for choice in response.choices {
|
|
||||||
if let Some(content) = choice.delta.content {
|
|
||||||
println!("Content: {}", content);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
println!("Error parsing JSON: {}", e);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_sse_streaming_try_from_bytes() {
|
|
||||||
let json_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}
|
|
||||||
data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"Hello, how can I help you today?"},"finish_reason":null}]}
|
|
||||||
data: [DONE]"#;
|
|
||||||
|
|
||||||
let iter = SseChatCompletionIter::try_from(json_data.as_bytes())
|
|
||||||
.expect("Failed to create SSE iterator");
|
|
||||||
|
|
||||||
println!("Testing SSE Streaming");
|
|
||||||
for item in iter {
|
|
||||||
match item {
|
|
||||||
Ok(response) => {
|
|
||||||
println!("Received response: {:?}", response);
|
|
||||||
if response.choices.is_empty() {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
for choice in response.choices {
|
|
||||||
if let Some(content) = choice.delta.content {
|
|
||||||
println!("Content: {}", content);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
println!("Error parsing JSON: {}", e);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn parse_chat_completions_request() {
|
|
||||||
const CHAT_COMPLETIONS_REQUEST: &str = r#"
|
|
||||||
{
|
|
||||||
"model": "None",
|
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": "how is the weather in seattle"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"stream": true
|
|
||||||
} "#;
|
|
||||||
|
|
||||||
let _chat_completions_request: ChatCompletionsRequest =
|
|
||||||
ChatCompletionsRequest::try_from(CHAT_COMPLETIONS_REQUEST.as_bytes())
|
|
||||||
.expect("Failed to parse ChatCompletionsRequest");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn stream_chunk_parse_claude() {
|
|
||||||
const CHUNK_RESPONSE: &str = r#"data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"role":"assistant"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"}
|
|
||||||
|
|
||||||
data: {"type": "ping"}
|
|
||||||
|
|
||||||
data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"content":"Hello!"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"}
|
|
||||||
|
|
||||||
data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"content":" How can I assist you today? Whether"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"}
|
|
||||||
|
|
||||||
data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"content":" you have a question, need information"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"}
|
|
||||||
|
|
||||||
data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"content":", or just want to chat about"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"}
|
|
||||||
|
|
||||||
data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"content":" something, I'm here to help. What woul"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"}
|
|
||||||
|
|
||||||
data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"content":"d you like to talk about?"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"}
|
|
||||||
|
|
||||||
data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"}
|
|
||||||
|
|
||||||
data: [DONE]
|
|
||||||
"#;
|
|
||||||
|
|
||||||
let iter = SseChatCompletionIter::try_from(CHUNK_RESPONSE.as_bytes());
|
|
||||||
|
|
||||||
assert!(iter.is_ok(), "Failed to create SSE iterator");
|
|
||||||
let iter: SseChatCompletionIter<str::Lines<'_>> = iter.unwrap();
|
|
||||||
|
|
||||||
let all_text: Vec<String> = iter
|
|
||||||
.map(|item| {
|
|
||||||
let response = item.expect("Failed to parse response");
|
|
||||||
response
|
|
||||||
.choices
|
|
||||||
.into_iter()
|
|
||||||
.filter_map(|choice| choice.delta.content)
|
|
||||||
.map(|content| content.to_string())
|
|
||||||
.collect::<String>()
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
all_text.len(),
|
|
||||||
8,
|
|
||||||
"Expected 8 chunks of text, but got {}",
|
|
||||||
all_text.len()
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
all_text.join(""),
|
|
||||||
"Hello! How can I assist you today? Whether you have a question, need information, or just want to chat about something, I'm here to help. What would you like to talk about?"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -6,59 +6,6 @@
|
||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
|
||||||
/// Conversion mode for provider requests/responses
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
|
||||||
pub enum ConversionMode {
|
|
||||||
/// Compatible: Convert between different provider formats to ensure compatibility
|
|
||||||
Compatible,
|
|
||||||
/// Passthrough: Pass requests/responses through with minimal modification
|
|
||||||
Passthrough,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Error types for provider operations
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct ProviderRequestError {
|
|
||||||
pub message: String,
|
|
||||||
pub source: Option<Box<dyn Error + Send + Sync>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct ProviderResponseError {
|
|
||||||
pub message: String,
|
|
||||||
pub source: Option<Box<dyn Error + Send + Sync>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl fmt::Display for ProviderRequestError {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
||||||
write!(f, "Provider request error: {}", self.message)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl fmt::Display for ProviderResponseError {
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
||||||
write!(f, "Provider response error: {}", self.message)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Error for ProviderRequestError {
|
|
||||||
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
|
||||||
self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Error for ProviderResponseError {
|
|
||||||
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
|
||||||
self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Trait for token usage information
|
|
||||||
pub trait TokenUsage {
|
|
||||||
fn completion_tokens(&self) -> usize;
|
|
||||||
fn prompt_tokens(&self) -> usize;
|
|
||||||
fn total_tokens(&self) -> usize;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Trait for provider-specific request types
|
/// Trait for provider-specific request types
|
||||||
pub trait ProviderRequest: Send + Sync {
|
pub trait ProviderRequest: Send + Sync {
|
||||||
/// Extract the model name from the request
|
/// Extract the model name from the request
|
||||||
|
|
@ -107,13 +54,26 @@ pub trait ProviderStreamResponse: Send + Sync {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Trait for streaming response iterators
|
/// Trait for streaming response iterators
|
||||||
///
|
|
||||||
/// This trait ensures that implementing types are iterators that yield
|
|
||||||
/// ProviderStreamResponse results.
|
|
||||||
pub trait ProviderStreamResponseIter: Iterator<Item = Result<Box<dyn ProviderStreamResponse>, Box<dyn std::error::Error + Send + Sync>>> + Send + Sync {
|
pub trait ProviderStreamResponseIter: Iterator<Item = Result<Box<dyn ProviderStreamResponse>, Box<dyn std::error::Error + Send + Sync>>> + Send + Sync {
|
||||||
// No additional methods needed - just the Iterator constraint with proper bounds
|
// No additional methods needed - just the Iterator constraint with proper bounds
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Conversion mode for provider requests/responses
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub enum ConversionMode {
|
||||||
|
/// Compatible: Convert between different provider formats to ensure compatibility
|
||||||
|
Compatible,
|
||||||
|
/// Passthrough: Pass requests/responses through with minimal modification
|
||||||
|
Passthrough,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Trait for token usage information
|
||||||
|
pub trait TokenUsage {
|
||||||
|
fn completion_tokens(&self) -> usize;
|
||||||
|
fn prompt_tokens(&self) -> usize;
|
||||||
|
fn total_tokens(&self) -> usize;
|
||||||
|
}
|
||||||
|
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
// PROVIDER FUNCTIONS - NO TRAITS, JUST PARAMETERIZED CONVERSION
|
// PROVIDER FUNCTIONS - NO TRAITS, JUST PARAMETERIZED CONVERSION
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
|
|
@ -152,14 +112,42 @@ pub trait ProviderStreamResponseIter: Iterator<Item = Result<Box<dyn ProviderStr
|
||||||
|
|
||||||
use crate::ProviderId;
|
use crate::ProviderId;
|
||||||
|
|
||||||
/// Parse request from bytes using provider ID - returns generic ProviderRequest trait object
|
// ============================================================================
|
||||||
pub fn try_request_from_bytes(bytes: &[u8], provider_id: &ProviderId) -> Result<Box<dyn ProviderRequest>, ProviderRequestError> {
|
// PROVIDER ADAPTER REGISTRY (Organizational Enhancement)
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
/// Provider adapter configuration
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ProviderConfig {
|
||||||
|
pub supported_apis: &'static [&'static str],
|
||||||
|
pub adapter_type: AdapterType,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum AdapterType {
|
||||||
|
OpenAICompatible,
|
||||||
|
// Future: Claude, Gemini, etc.
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get provider configuration
|
||||||
|
pub fn get_provider_config(provider_id: &ProviderId) -> ProviderConfig {
|
||||||
match provider_id {
|
match provider_id {
|
||||||
// All these providers currently use OpenAI-compatible chat completions API
|
|
||||||
// In the future, we can add provider-specific handling in separate match arms
|
|
||||||
ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek
|
ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek
|
||||||
| ProviderId::Arch | ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => {
|
| ProviderId::Arch | ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => {
|
||||||
|
ProviderConfig {
|
||||||
|
supported_apis: &["/v1/chat/completions"],
|
||||||
|
adapter_type: AdapterType::OpenAICompatible,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse request from bytes using provider ID - returns generic ProviderRequest trait object
|
||||||
|
pub fn try_request_from_bytes(bytes: &[u8], provider_id: &ProviderId) -> Result<Box<dyn ProviderRequest>, ProviderRequestError> {
|
||||||
|
let config = get_provider_config(provider_id);
|
||||||
|
|
||||||
|
match config.adapter_type {
|
||||||
|
AdapterType::OpenAICompatible => {
|
||||||
let request = crate::apis::openai::ChatCompletionsRequest::try_from((bytes, provider_id))
|
let request = crate::apis::openai::ChatCompletionsRequest::try_from((bytes, provider_id))
|
||||||
.map_err(|e| ProviderRequestError {
|
.map_err(|e| ProviderRequestError {
|
||||||
message: format!("Failed to parse request: {}", e),
|
message: format!("Failed to parse request: {}", e),
|
||||||
|
|
@ -175,9 +163,10 @@ pub fn try_request_from_bytes(bytes: &[u8], provider_id: &ProviderId) -> Result<
|
||||||
|
|
||||||
/// Parse response from bytes using provider ID - returns generic ProviderResponse trait object
|
/// Parse response from bytes using provider ID - returns generic ProviderResponse trait object
|
||||||
pub fn try_response_from_bytes(bytes: &[u8], provider_id: &ProviderId, _mode: ConversionMode) -> Result<Box<dyn ProviderResponse>, ProviderResponseError> {
|
pub fn try_response_from_bytes(bytes: &[u8], provider_id: &ProviderId, _mode: ConversionMode) -> Result<Box<dyn ProviderResponse>, ProviderResponseError> {
|
||||||
match provider_id {
|
let config = get_provider_config(provider_id);
|
||||||
ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek
|
|
||||||
| ProviderId::Arch | ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => {
|
match config.adapter_type {
|
||||||
|
AdapterType::OpenAICompatible => {
|
||||||
// Parameterized conversion allows provider-specific response parsing
|
// Parameterized conversion allows provider-specific response parsing
|
||||||
let response = crate::apis::openai::ChatCompletionsResponse::try_from((bytes, provider_id))
|
let response = crate::apis::openai::ChatCompletionsResponse::try_from((bytes, provider_id))
|
||||||
.map_err(|e| ProviderResponseError {
|
.map_err(|e| ProviderResponseError {
|
||||||
|
|
@ -192,13 +181,11 @@ pub fn try_response_from_bytes(bytes: &[u8], provider_id: &ProviderId, _mode: Co
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create streaming response using provider ID - returns clean ProviderStreamResponseIter trait object
|
/// Create streaming response using provider ID - returns clean ProviderStreamResponseIter trait object
|
||||||
///
|
|
||||||
/// This function returns a ProviderStreamResponseIter that's just an iterator,
|
|
||||||
/// eliminating the complex nested Result<Box<dyn Iterator<...>>> type completely.
|
|
||||||
pub fn try_streaming_from_bytes(bytes: &[u8], provider_id: &ProviderId, _mode: ConversionMode) -> Result<Box<dyn ProviderStreamResponseIter>, Box<dyn std::error::Error + Send + Sync>> {
|
pub fn try_streaming_from_bytes(bytes: &[u8], provider_id: &ProviderId, _mode: ConversionMode) -> Result<Box<dyn ProviderStreamResponseIter>, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
match provider_id {
|
let config = get_provider_config(provider_id);
|
||||||
ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek
|
|
||||||
| ProviderId::Arch | ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => {
|
match config.adapter_type {
|
||||||
|
AdapterType::OpenAICompatible => {
|
||||||
// Parse SSE (Server-Sent Events) streaming data
|
// Parse SSE (Server-Sent Events) streaming data
|
||||||
let s = std::str::from_utf8(bytes)?;
|
let s = std::str::from_utf8(bytes)?;
|
||||||
let lines: Vec<String> = s.lines().map(|line| line.to_string()).collect();
|
let lines: Vec<String> = s.lines().map(|line| line.to_string()).collect();
|
||||||
|
|
@ -211,29 +198,50 @@ pub fn try_streaming_from_bytes(bytes: &[u8], provider_id: &ProviderId, _mode: C
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Check if provider has compatible API
|
/// Check if provider has compatible API
|
||||||
///
|
|
||||||
/// Replaces the old ProviderInterface::has_compatible_api method.
|
|
||||||
/// This function enables runtime API compatibility checking without needing a provider instance.
|
|
||||||
pub fn has_compatible_api(provider_id: &ProviderId, api_path: &str) -> bool {
|
pub fn has_compatible_api(provider_id: &ProviderId, api_path: &str) -> bool {
|
||||||
match provider_id {
|
let config = get_provider_config(provider_id);
|
||||||
// Currently all these providers support OpenAI chat completions API
|
config.supported_apis.iter().any(|&supported| supported == api_path)
|
||||||
// Future providers with different APIs will get their own match arms
|
|
||||||
ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek
|
|
||||||
| ProviderId::Arch | ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => {
|
|
||||||
api_path == "/v1/chat/completions"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get supported APIs for provider
|
/// Get supported APIs for provider
|
||||||
///
|
|
||||||
/// Replaces the old ProviderInterface::supported_apis method.
|
|
||||||
/// Returns a static list of supported API endpoints for the given provider.
|
|
||||||
pub fn supported_apis(provider_id: &ProviderId) -> Vec<&'static str> {
|
pub fn supported_apis(provider_id: &ProviderId) -> Vec<&'static str> {
|
||||||
match provider_id {
|
let config = get_provider_config(provider_id);
|
||||||
ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek
|
config.supported_apis.to_vec()
|
||||||
| ProviderId::Arch | ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => {
|
}
|
||||||
vec!["/v1/chat/completions"]
|
|
||||||
}
|
/// Error types for provider operations
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ProviderRequestError {
|
||||||
|
pub message: String,
|
||||||
|
pub source: Option<Box<dyn Error + Send + Sync>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ProviderResponseError {
|
||||||
|
pub message: String,
|
||||||
|
pub source: Option<Box<dyn Error + Send + Sync>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for ProviderRequestError {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
write!(f, "Provider request error: {}", self.message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for ProviderResponseError {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
write!(f, "Provider response error: {}", self.message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Error for ProviderRequestError {
|
||||||
|
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
||||||
|
self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Error for ProviderResponseError {
|
||||||
|
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
||||||
|
self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -352,17 +352,16 @@ impl HttpContext for StreamContext {
|
||||||
};
|
};
|
||||||
|
|
||||||
// Set the resolved model using the trait method
|
// Set the resolved model using the trait method
|
||||||
deserialized_body.set_model(resolved_model);
|
deserialized_body.set_model(resolved_model.clone());
|
||||||
|
|
||||||
// Extract user message for tracing
|
// Extract user message for tracing
|
||||||
self.user_message = deserialized_body.extract_user_message();
|
self.user_message = deserialized_body.extract_user_message();
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
"on_http_request_body: provider: {}, model requested (in body): {}, model selected: {}, final model: {}",
|
"on_http_request_body: provider: {}, model requested (in body): {}, model selected: {}",
|
||||||
self.llm_provider().name,
|
self.llm_provider().name,
|
||||||
model_requested,
|
model_requested,
|
||||||
model_name.unwrap_or(&"None".to_string()),
|
model_name.unwrap_or(&"None".to_string()),
|
||||||
deserialized_body.model(),
|
|
||||||
);
|
);
|
||||||
|
|
||||||
// Use provider interface for streaming detection and setup
|
// Use provider interface for streaming detection and setup
|
||||||
|
|
@ -376,7 +375,7 @@ impl HttpContext for StreamContext {
|
||||||
// Use provider interface for text extraction (after potential mutation)
|
// Use provider interface for text extraction (after potential mutation)
|
||||||
let input_tokens_str = deserialized_body.extract_messages_text();
|
let input_tokens_str = deserialized_body.extract_messages_text();
|
||||||
// enforce ratelimits on ingress
|
// enforce ratelimits on ingress
|
||||||
if let Err(e) = self.enforce_ratelimits(&model_requested, input_tokens_str.as_str()) {
|
if let Err(e) = self.enforce_ratelimits(&resolved_model, input_tokens_str.as_str()) {
|
||||||
self.send_server_error(
|
self.send_server_error(
|
||||||
ServerError::ExceededRatelimit(e),
|
ServerError::ExceededRatelimit(e),
|
||||||
Some(StatusCode::TOO_MANY_REQUESTS),
|
Some(StatusCode::TOO_MANY_REQUESTS),
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ fn wasm_module() -> String {
|
||||||
wasm_file.exists(),
|
wasm_file.exists(),
|
||||||
"Run `cargo build --release --target=wasm32-wasip1` first"
|
"Run `cargo build --release --target=wasm32-wasip1` first"
|
||||||
);
|
);
|
||||||
wasm_file.to_str().unwrap().to_string()
|
wasm_file.to_string_lossy().to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn request_headers_expectations(module: &mut Tester, http_context: i32) {
|
fn request_headers_expectations(module: &mut Tester, http_context: i32) {
|
||||||
|
|
@ -267,17 +267,12 @@ fn llm_gateway_bad_request_to_open_ai_chat_completions() {
|
||||||
.returning(Some(incomplete_chat_completions_request_body))
|
.returning(Some(incomplete_chat_completions_request_body))
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Debug), None)
|
||||||
.expect_log(Some(LogLevel::Info), Some("on_http_request_body: provider: open-ai-gpt-4, model requested (in body): gpt-1, model selected: gpt-4"))
|
.expect_log(Some(LogLevel::Info), Some("on_http_request_body: provider: open-ai-gpt-4, model requested (in body): gpt-1, model selected: gpt-4"))
|
||||||
.expect_send_local_response(
|
.expect_log(Some(LogLevel::Debug), Some("getting token count model=gpt-4"))
|
||||||
Some(StatusCode::BAD_REQUEST.as_u16().into()),
|
.expect_log(Some(LogLevel::Debug), Some("Recorded input token count: 13"))
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
|
||||||
.expect_metric_record("input_sequence_length", 13)
|
.expect_metric_record("input_sequence_length", 13)
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Debug), Some("Applying ratelimit for model: gpt-4"))
|
||||||
|
.expect_log(Some(LogLevel::Debug), Some(r#"Checking limit for provider=gpt-4, with selector=Header { key: "selector-key", value: "selector-value" }, consuming tokens=13"#))
|
||||||
|
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
@ -386,11 +381,11 @@ fn llm_gateway_request_not_ratelimited() {
|
||||||
.returning(Some(chat_completions_request_body))
|
.returning(Some(chat_completions_request_body))
|
||||||
// The actual call is not important in this test, we just need to grab the token_id
|
// The actual call is not important in this test, we just need to grab the token_id
|
||||||
.expect_log(Some(LogLevel::Info), None)
|
.expect_log(Some(LogLevel::Info), None)
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Debug), Some("getting token count model=gpt-4"))
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Debug), Some("Recorded input token count: 29"))
|
||||||
.expect_metric_record("input_sequence_length", 29)
|
.expect_metric_record("input_sequence_length", 29)
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Debug), Some("Applying ratelimit for model: gpt-4"))
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Debug), Some(r#"Checking limit for provider=gpt-4, with selector=Header { key: "selector-key", value: "selector-value" }, consuming tokens=29"#))
|
||||||
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
@ -433,11 +428,11 @@ fn llm_gateway_override_model_name() {
|
||||||
// The actual call is not important in this test, we just need to grab the token_id
|
// The actual call is not important in this test, we just need to grab the token_id
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Debug), None)
|
||||||
.expect_log(Some(LogLevel::Info), Some("on_http_request_body: provider: open-ai-gpt-4, model requested (in body): gpt-1, model selected: gpt-4"))
|
.expect_log(Some(LogLevel::Info), Some("on_http_request_body: provider: open-ai-gpt-4, model requested (in body): gpt-1, model selected: gpt-4"))
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Debug), Some("getting token count model=gpt-4"))
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Debug), Some("Recorded input token count: 29"))
|
||||||
.expect_metric_record("input_sequence_length", 29)
|
.expect_metric_record("input_sequence_length", 29)
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Debug), Some("Applying ratelimit for model: gpt-4"))
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Debug), Some(r#"Checking limit for provider=gpt-4, with selector=Header { key: "selector-key", value: "selector-value" }, consuming tokens=29"#))
|
||||||
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
@ -483,8 +478,8 @@ fn llm_gateway_override_use_default_model() {
|
||||||
Some(LogLevel::Info),
|
Some(LogLevel::Info),
|
||||||
Some("on_http_request_body: provider: open-ai-gpt-4, model requested (in body): gpt-1, model selected: gpt-4"),
|
Some("on_http_request_body: provider: open-ai-gpt-4, model requested (in body): gpt-1, model selected: gpt-4"),
|
||||||
)
|
)
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Debug), Some("getting token count model=gpt-4"))
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Debug), Some("Recorded input token count: 29"))
|
||||||
.expect_metric_record("input_sequence_length", 29)
|
.expect_metric_record("input_sequence_length", 29)
|
||||||
.expect_log(Some(LogLevel::Debug), Some("Applying ratelimit for model: gpt-4"))
|
.expect_log(Some(LogLevel::Debug), Some("Applying ratelimit for model: gpt-4"))
|
||||||
.expect_log(Some(LogLevel::Debug), Some(r#"Checking limit for provider=gpt-4, with selector=Header { key: "selector-key", value: "selector-value" }, consuming tokens=29"#))
|
.expect_log(Some(LogLevel::Debug), Some(r#"Checking limit for provider=gpt-4, with selector=Header { key: "selector-key", value: "selector-value" }, consuming tokens=29"#))
|
||||||
|
|
@ -530,11 +525,11 @@ fn llm_gateway_override_use_model_name_none() {
|
||||||
// The actual call is not important in this test, we just need to grab the token_id
|
// The actual call is not important in this test, we just need to grab the token_id
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Debug), None)
|
||||||
.expect_log(Some(LogLevel::Info), Some("on_http_request_body: provider: open-ai-gpt-4, model requested (in body): none, model selected: gpt-4"))
|
.expect_log(Some(LogLevel::Info), Some("on_http_request_body: provider: open-ai-gpt-4, model requested (in body): none, model selected: gpt-4"))
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Debug), Some("getting token count model=gpt-4"))
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Debug), Some("Recorded input token count: 29"))
|
||||||
.expect_metric_record("input_sequence_length", 29)
|
.expect_metric_record("input_sequence_length", 29)
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Debug), Some("Applying ratelimit for model: gpt-4"))
|
||||||
.expect_log(Some(LogLevel::Debug), None)
|
.expect_log(Some(LogLevel::Debug), Some(r#"Checking limit for provider=gpt-4, with selector=Header { key: "selector-key", value: "selector-value" }, consuming tokens=29"#))
|
||||||
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue