diff --git a/crates/brightstaff/src/handlers/chat_completions.rs b/crates/brightstaff/src/handlers/chat_completions.rs index 37da961f..d0e5910a 100644 --- a/crates/brightstaff/src/handlers/chat_completions.rs +++ b/crates/brightstaff/src/handlers/chat_completions.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use bytes::Bytes; use common::configuration::ModelUsagePreference; use common::consts::ARCH_PROVIDER_HINT_HEADER; -use hermesllm::providers::openai::types::ChatCompletionsRequest; +use hermesllm::apis::openai::ChatCompletionsRequest; use http_body_util::combinators::BoxBody; use http_body_util::{BodyExt, Full, StreamBody}; use hyper::body::Frame; @@ -93,7 +93,7 @@ pub async fn chat_completions( chat_completion_request.metadata.and_then(|metadata| { metadata .get("archgw_preference_config") - .and_then(|value| value.as_str().map(String::from)) + .map(|value| value.to_string()) }); let usage_preferences: Option> = usage_preferences_str @@ -105,9 +105,7 @@ pub async fn chat_completions( .messages .last() .map_or("None".to_string(), |msg| { - msg.content.as_ref().map_or("None".to_string(), |content| { - content.to_string().replace('\n', "\\n") - }) + msg.content.to_string().replace('\n', "\\n") }); const MAX_MESSAGE_LENGTH: usize = 50; diff --git a/crates/brightstaff/src/handlers/models.rs b/crates/brightstaff/src/handlers/models.rs index 3a4662a6..ac1bbebe 100644 --- a/crates/brightstaff/src/handlers/models.rs +++ b/crates/brightstaff/src/handlers/models.rs @@ -1,6 +1,6 @@ use bytes::Bytes; use common::configuration::{IntoModels, LlmProvider}; -use hermesllm::providers::openai::types::Models; +use hermesllm::apis::openai::Models; use http_body_util::{combinators::BoxBody, BodyExt, Full}; use hyper::{Response, StatusCode}; use serde_json; diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index b5bf0204..34fa3aa3 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -98,7 +98,7 @@ async fn main() -> Result<(), Box> { let peer_addr = stream.peer_addr()?; let io = TokioIo::new(stream); - let router_service = Arc::clone(&router_service); + let router_service: Arc = Arc::clone(&router_service); let llm_provider_endpoint = llm_provider_endpoint.clone(); let llm_providers = llm_providers.clone(); diff --git a/crates/brightstaff/src/router/llm_router.rs b/crates/brightstaff/src/router/llm_router.rs index fc6d9365..3b09c115 100644 --- a/crates/brightstaff/src/router/llm_router.rs +++ b/crates/brightstaff/src/router/llm_router.rs @@ -4,7 +4,7 @@ use common::{ configuration::{LlmProvider, ModelUsagePreference, RoutingPreference}, consts::ARCH_PROVIDER_HINT_HEADER, }; -use hermesllm::providers::openai::types::{ChatCompletionsResponse, ContentType, Message}; +use hermesllm::apis::openai::{ChatCompletionsResponse, Message}; use hyper::header; use thiserror::Error; use tracing::{debug, info, warn}; @@ -153,9 +153,7 @@ impl RouterService { return Ok(None); } - if let Some(ContentType::Text(content)) = - &chat_completion_response.choices[0].message.content - { + if let Some(content) = &chat_completion_response.choices[0].message.content { let parsed_response = self .router_model .parse_response(content, &usage_preferences)?; diff --git a/crates/brightstaff/src/router/router_model.rs b/crates/brightstaff/src/router/router_model.rs index ec0c1a1f..372907af 100644 --- a/crates/brightstaff/src/router/router_model.rs +++ b/crates/brightstaff/src/router/router_model.rs @@ -1,5 +1,5 @@ use common::configuration::ModelUsagePreference; -use hermesllm::providers::openai::types::{ChatCompletionsRequest, Message}; +use hermesllm::apis::openai::{ChatCompletionsRequest, Message}; use thiserror::Error; #[derive(Debug, Error)] diff --git a/crates/brightstaff/src/router/router_model_v1.rs b/crates/brightstaff/src/router/router_model_v1.rs index bd06b525..1c1c14ef 100644 --- a/crates/brightstaff/src/router/router_model_v1.rs +++ b/crates/brightstaff/src/router/router_model_v1.rs @@ -2,9 +2,8 @@ use std::collections::HashMap; use common::{ configuration::{ModelUsagePreference, RoutingPreference}, - consts::{SYSTEM_ROLE, TOOL_ROLE, USER_ROLE}, }; -use hermesllm::providers::openai::types::{ChatCompletionsRequest, ContentType, Message}; +use hermesllm::apis::openai::{ChatCompletionsRequest, MessageContent, Message, Role}; use serde::{Deserialize, Serialize}; use tracing::{debug, warn}; @@ -80,7 +79,9 @@ impl RouterModel for RouterModelV1 { // when role == tool its tool call response let messages_vec = messages .iter() - .filter(|m| m.role != SYSTEM_ROLE && m.role != TOOL_ROLE && m.content.is_some()) + .filter(|m| { + m.role != Role::System && m.role != Role::Tool && !m.content.to_string().is_empty() + }) .collect::>(); // Following code is to ensure that the conversation does not exceed max token length @@ -88,13 +89,7 @@ impl RouterModel for RouterModelV1 { let mut token_count = ARCH_ROUTER_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR; let mut selected_messages_list_reversed: Vec<&Message> = vec![]; for (selected_messsage_count, message) in messages_vec.iter().rev().enumerate() { - let message_token_count = message - .content - .as_ref() - .unwrap_or(&ContentType::Text("".to_string())) - .to_string() - .len() - / TOKEN_LENGTH_DIVISOR; + let message_token_count = message.content.to_string().len() / TOKEN_LENGTH_DIVISOR; token_count += message_token_count; if token_count > self.max_token_length { debug!( @@ -104,7 +99,7 @@ impl RouterModel for RouterModelV1 { , selected_messsage_count, messages_vec.len() ); - if message.role == USER_ROLE { + if message.role == Role::User { // If message that exceeds max token length is from user, we need to keep it selected_messages_list_reversed.push(message); } @@ -125,12 +120,12 @@ impl RouterModel for RouterModelV1 { // ensure that first and last selected message is from user if let Some(first_message) = selected_messages_list_reversed.first() { - if first_message.role != USER_ROLE { + if first_message.role != Role::User { warn!("RouterModelV1: last message in the conversation is not from user, this may lead to incorrect routing"); } } if let Some(last_message) = selected_messages_list_reversed.last() { - if last_message.role != USER_ROLE { + if last_message.role != Role::User { warn!("RouterModelV1: first message in the conversation is not from user, this may lead to incorrect routing"); } } @@ -143,9 +138,10 @@ impl RouterModel for RouterModelV1 { Message { role: message.role.clone(), // we can unwrap here because we have already filtered out messages without content - content: Some(ContentType::Text( - message.content.as_ref().unwrap().to_string(), - )), + content: MessageContent::Text(message.content.to_string()), + name: None, + tool_calls: None, + tool_call_id: None, } }) .collect::>(); @@ -160,8 +156,11 @@ impl RouterModel for RouterModelV1 { ChatCompletionsRequest { model: self.routing_model.clone(), messages: vec![Message { - content: Some(ContentType::Text(router_message)), - role: USER_ROLE.to_string(), + content: MessageContent::Text(router_message), + role: Role::User, + name: None, + tool_calls: None, + tool_call_id: None, }], temperature: Some(0.01), ..Default::default() @@ -347,9 +346,9 @@ Based on your analysis, provide your response in the following JSON formats if y let req = router.generate_request(&conversation, &None); - let prompt = req.messages[0].content.as_ref().unwrap(); + let prompt = req.messages[0].content.to_string(); - assert_eq!(expected_prompt, prompt.to_string()); + assert_eq!(expected_prompt, prompt); } #[test] @@ -412,9 +411,9 @@ Based on your analysis, provide your response in the following JSON formats if y }]); let req = router.generate_request(&conversation, &usage_preferences); - let prompt = req.messages[0].content.as_ref().unwrap(); + let prompt = req.messages[0].content.to_string(); - assert_eq!(expected_prompt, prompt.to_string()); + assert_eq!(expected_prompt, prompt); } #[test] @@ -472,9 +471,9 @@ Based on your analysis, provide your response in the following JSON formats if y let req = router.generate_request(&conversation, &None); - let prompt = req.messages[0].content.as_ref().unwrap(); + let prompt = req.messages[0].content.to_string(); - assert_eq!(expected_prompt, prompt.to_string()); + assert_eq!(expected_prompt, prompt); } #[test] @@ -533,9 +532,9 @@ Based on your analysis, provide your response in the following JSON formats if y let req = router.generate_request(&conversation, &None); - let prompt = req.messages[0].content.as_ref().unwrap(); + let prompt = req.messages[0].content.to_string(); - assert_eq!(expected_prompt, prompt.to_string()); + assert_eq!(expected_prompt, prompt); } #[test] @@ -601,9 +600,9 @@ Based on your analysis, provide your response in the following JSON formats if y let req = router.generate_request(&conversation, &None); - let prompt = req.messages[0].content.as_ref().unwrap(); + let prompt = req.messages[0].content.to_string(); - assert_eq!(expected_prompt, prompt.to_string()); + assert_eq!(expected_prompt, prompt); } #[test] @@ -670,9 +669,9 @@ Based on your analysis, provide your response in the following JSON formats if y let req = router.generate_request(&conversation, &None); - let prompt = req.messages[0].content.as_ref().unwrap(); + let prompt = req.messages[0].content.to_string(); - assert_eq!(expected_prompt, prompt.to_string()); + assert_eq!(expected_prompt, prompt); } #[test] @@ -716,14 +715,14 @@ Based on your analysis, provide your response in the following JSON formats if y }, { "role": "assistant", - "content": null, + "content": "", "tool_calls": [ { "id": "toolcall-abc123", "type": "function", "function": { "name": "get_weather", - "arguments": { "location": "Tokyo" } + "arguments": "{ \"location\": \"Tokyo\" }" } } ] @@ -763,11 +762,11 @@ Based on your analysis, provide your response in the following JSON formats if y let conversation: Vec = serde_json::from_str(conversation_str).unwrap(); - let req = router.generate_request(&conversation, &None); + let req: ChatCompletionsRequest = router.generate_request(&conversation, &None); - let prompt = req.messages[0].content.as_ref().unwrap(); + let prompt = req.messages[0].content.to_string(); - assert_eq!(expected_prompt, prompt.to_string()); + assert_eq!(expected_prompt, prompt); } #[test] diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index 009e3c53..20d2623b 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -1,4 +1,4 @@ -use hermesllm::providers::openai::types::{ModelDetail, ModelObject, Models}; +use hermesllm::apis::openai::{ModelDetail, ModelObject, Models}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::fmt::Display; diff --git a/crates/common/src/errors.rs b/crates/common/src/errors.rs index 582c0a7c..21af3c94 100644 --- a/crates/common/src/errors.rs +++ b/crates/common/src/errors.rs @@ -1,7 +1,7 @@ use proxy_wasm::types::Status; use crate::{api::open_ai::ChatCompletionChunkResponseError, ratelimit}; -use hermesllm::providers::openai::types::OpenAIError; +use hermesllm::apis::openai::OpenAIError; #[derive(thiserror::Error, Debug)] pub enum ClientError { diff --git a/crates/hermesllm/src/apis/openai.rs b/crates/hermesllm/src/apis/openai.rs index 8a18a441..15c77bed 100644 --- a/crates/hermesllm/src/apis/openai.rs +++ b/crates/hermesllm/src/apis/openai.rs @@ -2,6 +2,8 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; use serde_with::skip_serializing_none; use std::collections::HashMap; +use std::fmt::Display; +use thiserror::Error; use crate::{providers::ProviderRequestError, ConversionMode, ProviderRequest}; use super::ApiDefinition; @@ -116,8 +118,8 @@ pub enum Role { #[skip_serializing_none] #[derive(Serialize, Deserialize, Debug, Clone)] pub struct Message { - pub content: MessageContent, pub role: Role, + pub content: MessageContent, pub name: Option, /// Tool calls made by the assistant (only present for assistant role) pub tool_calls: Option>, @@ -171,6 +173,28 @@ pub enum MessageContent { Parts(Vec), } +impl Display for MessageContent { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + MessageContent::Text(text) => write!(f, "{}", text), + MessageContent::Parts(parts) => { + let text_parts: Vec = parts + .iter() + .filter_map(|part| match part { + ContentPart::Text { text } => Some(text.clone()), + ContentPart::ImageUrl { .. } => { + // skip image URLs or their data in text representation + None + } + }) + .collect(); + let combined_text = text_parts.join("\n"); + write!(f, "{}", combined_text) + } + } + } +} + /// Individual content part within a message (text or image) #[derive(Serialize, Deserialize, Debug, Clone)] #[serde(tag = "type")] @@ -560,6 +584,26 @@ impl TokenUsage for Usage { } } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelDetail { + pub id: String, + pub object: String, + pub created: usize, + pub owned_by: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum ModelObject { + #[serde(rename = "list")] + List, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Models { + pub object: ModelObject, + pub data: Vec, +} + // Error type for streaming operations #[derive(Debug, thiserror::Error)] pub enum OpenAIStreamError { @@ -571,6 +615,22 @@ pub enum OpenAIStreamError { InvalidStreamingData(String), } +#[derive(Debug, Error)] +pub enum OpenAIError { + #[error("json error: {0}")] + JsonParseError(#[from] serde_json::Error), + #[error("utf8 parsing error: {0}")] + Utf8Error(#[from] std::str::Utf8Error), + #[error("invalid streaming data err {source}, data: {data}")] + InvalidStreamingData { + source: serde_json::Error, + data: String, + }, + #[error("unsupported provider: {provider}")] + UnsupportedProvider { provider: String }, +} + + /// SSE-based streaming iterator for OpenAI chat completions /// Implements ProviderStreamResponseIter directly pub struct SseChatCompletionIter diff --git a/crates/hermesllm/src/providers/openai/builder.rs b/crates/hermesllm/src/providers/openai/builder.rs deleted file mode 100644 index fa1f325e..00000000 --- a/crates/hermesllm/src/providers/openai/builder.rs +++ /dev/null @@ -1,114 +0,0 @@ -use serde_json::Value; - -use crate::providers::openai::types::{ChatCompletionsRequest, Message, StreamOptions}; - -#[derive(Debug, Clone)] -pub struct OpenAIRequestBuilder { - model: String, - messages: Vec, - temperature: Option, - top_p: Option, - n: Option, - max_tokens: Option, - stream: Option, - stop: Option>, - presence_penalty: Option, - frequency_penalty: Option, - stream_options: Option, - tools: Option>, -} - -impl OpenAIRequestBuilder { - pub fn new(model: impl Into, messages: Vec) -> Self { - Self { - model: model.into(), - messages, - temperature: None, - top_p: None, - n: None, - max_tokens: None, - stream: None, - stop: None, - presence_penalty: None, - frequency_penalty: None, - stream_options: None, - tools: None, - } - } - - pub fn temperature(mut self, temperature: f32) -> Self { - self.temperature = Some(temperature); - self - } - - pub fn top_p(mut self, top_p: f32) -> Self { - self.top_p = Some(top_p); - self - } - - pub fn n(mut self, n: u32) -> Self { - self.n = Some(n); - self - } - - pub fn max_tokens(mut self, max_tokens: u32) -> Self { - self.max_tokens = Some(max_tokens); - self - } - - pub fn stream(mut self, stream: bool) -> Self { - self.stream = Some(stream); - self - } - - pub fn stop(mut self, stop: Vec) -> Self { - self.stop = Some(stop); - self - } - - pub fn presence_penalty(mut self, presence_penalty: f32) -> Self { - self.presence_penalty = Some(presence_penalty); - self - } - - pub fn frequency_penalty(mut self, frequency_penalty: f32) -> Self { - self.frequency_penalty = Some(frequency_penalty); - self - } - - pub fn stream_options(mut self, include_usage: bool) -> Self { - self.stream = Some(true); - self.stream_options = Some(StreamOptions { include_usage }); - self - } - - pub fn tools(mut self, tools: Vec) -> Self { - self.tools = Some(tools); - self - } - - pub fn build(self) -> Result { - let request = ChatCompletionsRequest { - model: self.model, - messages: self.messages, - temperature: self.temperature, - top_p: self.top_p, - n: self.n, - max_tokens: self.max_tokens, - stream: self.stream, - stop: self.stop, - presence_penalty: self.presence_penalty, - frequency_penalty: self.frequency_penalty, - stream_options: self.stream_options, - tools: self.tools, - metadata: None, - }; - Ok(request) - } -} - -impl ChatCompletionsRequest { - pub fn builder(model: impl Into, messages: Vec) -> OpenAIRequestBuilder { - OpenAIRequestBuilder::new(model, messages) - } -} diff --git a/crates/hermesllm/src/providers/openai/mod.rs b/crates/hermesllm/src/providers/openai/mod.rs index ddab6913..d82d5ab0 100644 --- a/crates/hermesllm/src/providers/openai/mod.rs +++ b/crates/hermesllm/src/providers/openai/mod.rs @@ -1,10 +1,5 @@ -pub mod builder; -pub mod types; - // Re-export the main types and builder functionality pub use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, ChatCompletionsStreamResponse}; -pub use builder::*; -pub use types::*; // Note: The OpenAIProvider struct has been deprecated in favor of the function-based approach in traits.rs // All provider functionality is now accessed through try_request_from_bytes, try_response_from_bytes, etc. diff --git a/crates/hermesllm/src/providers/openai/types.rs b/crates/hermesllm/src/providers/openai/types.rs deleted file mode 100644 index 4dcd7d22..00000000 --- a/crates/hermesllm/src/providers/openai/types.rs +++ /dev/null @@ -1,552 +0,0 @@ -use std::collections::HashMap; -use std::fmt::Display; - -use serde::{Deserialize, Serialize}; -use serde_json::Value; -use serde_with::skip_serializing_none; -use std::convert::TryFrom; -use std::str; -use thiserror::Error; - -use crate::providers::ProviderId; - -#[derive(Debug, Error)] -pub enum OpenAIError { - #[error("json error: {0}")] - JsonParseError(#[from] serde_json::Error), - #[error("utf8 parsing error: {0}")] - Utf8Error(#[from] std::str::Utf8Error), - #[error("invalid streaming data err {source}, data: {data}")] - InvalidStreamingData { - source: serde_json::Error, - data: String, - }, - #[error("unsupported provider: {provider}")] - UnsupportedProvider { provider: String }, -} - -type Result = std::result::Result; - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub enum MultiPartContentType { - #[serde(rename = "text")] - Text, - #[serde(rename = "image_url")] - ImageUrl, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub struct ImageUrl { - pub url: String, -} - -#[skip_serializing_none] -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub struct MultiPartContent { - pub text: Option, - pub image_url: Option, - #[serde(rename = "type")] - pub content_type: MultiPartContentType, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[serde(untagged)] -pub enum ContentType { - Text(String), - MultiPart(Vec), -} - -impl Display for ContentType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - ContentType::Text(text) => write!(f, "{}", text), - ContentType::MultiPart(multi_part) => { - let text_parts: Vec = multi_part - .iter() - .filter_map(|part| { - if part.content_type == MultiPartContentType::Text { - part.text.clone() - } else if part.content_type == MultiPartContentType::ImageUrl { - // skip image URLs or their data in text representation - None - } else { - panic!("Unsupported content type: {:?}", part.content_type); - } - }) - .collect(); - let combined_text = text_parts.join("\n"); - write!(f, "{}", combined_text) - } - } - } -} - -#[skip_serializing_none] -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Message { - pub role: String, - pub content: Option, -} - -impl Message { - pub fn new(content: String) -> Self { - Self { - role: "user".to_string(), - content: Some(ContentType::Text(content)), - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct StreamOptions { - pub include_usage: bool, -} - -#[skip_serializing_none] -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -pub struct ChatCompletionsRequest { - pub model: String, - pub messages: Vec, - pub temperature: Option, - pub top_p: Option, - pub n: Option, - pub max_tokens: Option, - pub stream: Option, - pub stop: Option>, - pub presence_penalty: Option, - pub frequency_penalty: Option, - pub stream_options: Option, - pub tools: Option>, - pub metadata: Option>, -} - -impl TryFrom<&[u8]> for ChatCompletionsRequest { - type Error = OpenAIError; - fn try_from(bytes: &[u8]) -> Result { - serde_json::from_slice(bytes).map_err(OpenAIError::from) - } -} - -#[skip_serializing_none] -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ChatCompletionsResponse { - pub id: String, - pub object: String, - pub created: u64, - pub choices: Vec, - pub usage: Option, -} - -impl TryFrom<&[u8]> for ChatCompletionsResponse { - type Error = OpenAIError; - fn try_from(bytes: &[u8]) -> Result { - serde_json::from_slice(bytes).map_err(OpenAIError::from) - } -} - -impl<'a> TryFrom<(&'a [u8], &'a ProviderId)> for ChatCompletionsResponse { - type Error = OpenAIError; - - fn try_from(input: (&'a [u8], &'a ProviderId)) -> Result { - // Use input.provider as needed, if necessary - serde_json::from_slice(input.0).map_err(OpenAIError::from) - } -} - -impl ChatCompletionsRequest { - pub fn to_bytes(&self, provider: ProviderId) -> Result> { - match provider { - ProviderId::OpenAI - | ProviderId::Arch - | ProviderId::Deepseek - | ProviderId::Mistral - | ProviderId::Groq - | ProviderId::Gemini - | ProviderId::Claude - | ProviderId::GitHub => serde_json::to_vec(self).map_err(OpenAIError::from), - } - } -} - -#[skip_serializing_none] -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct Choice { - pub index: u32, - pub message: Message, - pub finish_reason: Option, -} - -#[skip_serializing_none] -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct Usage { - pub prompt_tokens: usize, - pub completion_tokens: usize, - pub total_tokens: usize, -} - -#[skip_serializing_none] -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DeltaMessage { - pub role: Option, - pub content: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct StreamChoice { - pub index: u32, - pub delta: DeltaMessage, - pub finish_reason: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ChatCompletionStreamResponse { - pub id: String, - pub object: String, - pub created: u64, - pub model: String, - pub choices: Vec, - pub usage: Option, -} - -pub struct SseChatCompletionIter -where - I: Iterator, - I::Item: AsRef, -{ - lines: I, -} - -impl SseChatCompletionIter -where - I: Iterator, - I::Item: AsRef, -{ - pub fn new(lines: I) -> Self { - Self { lines } - } -} - -impl Iterator for SseChatCompletionIter -where - I: Iterator, - I::Item: AsRef, -{ - type Item = Result; - - fn next(&mut self) -> Option { - for line in &mut self.lines { - let line = line.as_ref(); - if let Some(data) = line.strip_prefix("data: ") { - let data = data.trim(); - if data == "[DONE]" { - return None; - } - - if data == r#"{"type": "ping"}"# { - continue; // Skip ping messages - that is usually from anthropic - } - - return Some( - serde_json::from_str::(data).map_err(|e| { - OpenAIError::InvalidStreamingData { - source: e, - data: data.to_string(), - } - }), - ); - } - } - None - } -} - - -impl<'a> TryFrom<&'a [u8]> for SseChatCompletionIter> { - type Error = OpenAIError; - - fn try_from(bytes: &'a [u8]) -> Result { - let s = std::str::from_utf8(bytes)?; - Ok(SseChatCompletionIter::new(s.lines())) - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ModelDetail { - pub id: String, - pub object: String, - pub created: usize, - pub owned_by: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum ModelObject { - #[serde(rename = "list")] - List, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Models { - pub object: ModelObject, - pub data: Vec, -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_content_type_display() { - let text_content = ContentType::Text("Hello, world!".to_string()); - assert_eq!(text_content.to_string(), "Hello, world!"); - - let multi_part_content = ContentType::MultiPart(vec![ - MultiPartContent { - text: Some("This is a text part.".to_string()), - content_type: MultiPartContentType::Text, - image_url: None, - }, - MultiPartContent { - text: Some("https://example.com/image.png".to_string()), - content_type: MultiPartContentType::ImageUrl, - image_url: None, - }, - ]); - assert_eq!(multi_part_content.to_string(), "This is a text part."); - } - - #[test] - fn test_chat_completions_request_text_type_array() { - const CHAT_COMPLETIONS_REQUEST: &str = r#" - { - "model": "gpt-3.5-turbo", - "messages": [ - { - "role": "user", - "content": [ - { - "type": "text", - "text": "What city do you want to know the weather for?" - }, - { - "type": "text", - "text": "hello world" - } - ] - } - ] - } - "#; - - let chat_completions_request: ChatCompletionsRequest = - serde_json::from_str(CHAT_COMPLETIONS_REQUEST).unwrap(); - assert_eq!(chat_completions_request.model, "gpt-3.5-turbo"); - if let Some(ContentType::MultiPart(multi_part_content)) = - chat_completions_request.messages[0].content.as_ref() - { - assert_eq!(multi_part_content.len(), 2); - assert_eq!( - multi_part_content[0].content_type, - MultiPartContentType::Text - ); - assert_eq!( - multi_part_content[0].text, - Some("What city do you want to know the weather for?".to_string()) - ); - assert_eq!( - multi_part_content[1].content_type, - MultiPartContentType::Text - ); - assert_eq!(multi_part_content[1].text, Some("hello world".to_string())); - } else { - panic!("Expected MultiPartContent"); - } - } - - #[test] - fn test_chat_completions_request_image_content() { - const CHAT_COMPLETIONS_REQUEST: &str = r#" - { - "stream": true, - "model": "openai/gpt-4o", - "messages": [ - { - "role": "user", - "content": [ - { - "type": "text", - "text": "describe this photo pls" - }, - { - "type": "image_url", - "image_url": { - "url": "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/...==" - } - } - ] - } - ] - }"#; - - let chat_completions_request: ChatCompletionsRequest = - serde_json::from_str(CHAT_COMPLETIONS_REQUEST).unwrap(); - assert_eq!(chat_completions_request.model, "openai/gpt-4o"); - if let Some(ContentType::MultiPart(multi_part_content)) = - chat_completions_request.messages[0].content.as_ref() - { - assert_eq!(multi_part_content.len(), 2); - assert_eq!( - multi_part_content[0].content_type, - MultiPartContentType::Text - ); - assert_eq!( - multi_part_content[0].text, - Some("describe this photo pls".to_string()) - ); - assert_eq!( - multi_part_content[1].content_type, - MultiPartContentType::ImageUrl - ); - assert_eq!( - multi_part_content[1].image_url, - Some(ImageUrl { - url: "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/...==".to_string(), - }) - ); - } else { - panic!("Expected MultiPartContent"); - } - } - - #[test] - fn test_sse_streaming() { - let json_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]} -data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"Hello, how can I help you today?"},"finish_reason":null}]} -data: [DONE]"#; - - let iter = SseChatCompletionIter::new(json_data.lines()); - - println!("Testing SSE Streaming"); - for item in iter { - match item { - Ok(response) => { - println!("Received response: {:?}", response); - if response.choices.is_empty() { - continue; - } - for choice in response.choices { - if let Some(content) = choice.delta.content { - println!("Content: {}", content); - } - } - } - Err(e) => { - println!("Error parsing JSON: {}", e); - return; - } - } - } - } - - #[test] - fn test_sse_streaming_try_from_bytes() { - let json_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]} -data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"Hello, how can I help you today?"},"finish_reason":null}]} -data: [DONE]"#; - - let iter = SseChatCompletionIter::try_from(json_data.as_bytes()) - .expect("Failed to create SSE iterator"); - - println!("Testing SSE Streaming"); - for item in iter { - match item { - Ok(response) => { - println!("Received response: {:?}", response); - if response.choices.is_empty() { - continue; - } - for choice in response.choices { - if let Some(content) = choice.delta.content { - println!("Content: {}", content); - } - } - } - Err(e) => { - println!("Error parsing JSON: {}", e); - return; - } - } - } - } - - #[test] - fn parse_chat_completions_request() { - const CHAT_COMPLETIONS_REQUEST: &str = r#" -{ - "model": "None", - "messages": [ - { - "role": "user", - "content": "how is the weather in seattle" - } - ], - "stream": true -} "#; - - let _chat_completions_request: ChatCompletionsRequest = - ChatCompletionsRequest::try_from(CHAT_COMPLETIONS_REQUEST.as_bytes()) - .expect("Failed to parse ChatCompletionsRequest"); - } - - #[test] - fn stream_chunk_parse_claude() { - const CHUNK_RESPONSE: &str = r#"data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"role":"assistant"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"} - -data: {"type": "ping"} - -data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"content":"Hello!"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"} - -data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"content":" How can I assist you today? Whether"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"} - -data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"content":" you have a question, need information"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"} - -data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"content":", or just want to chat about"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"} - -data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"content":" something, I'm here to help. What woul"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"} - -data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"content":"d you like to talk about?"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"} - -data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"} - -data: [DONE] -"#; - - let iter = SseChatCompletionIter::try_from(CHUNK_RESPONSE.as_bytes()); - - assert!(iter.is_ok(), "Failed to create SSE iterator"); - let iter: SseChatCompletionIter> = iter.unwrap(); - - let all_text: Vec = iter - .map(|item| { - let response = item.expect("Failed to parse response"); - response - .choices - .into_iter() - .filter_map(|choice| choice.delta.content) - .map(|content| content.to_string()) - .collect::() - }) - .collect(); - - assert_eq!( - all_text.len(), - 8, - "Expected 8 chunks of text, but got {}", - all_text.len() - ); - - assert_eq!( - all_text.join(""), - "Hello! How can I assist you today? Whether you have a question, need information, or just want to chat about something, I'm here to help. What would you like to talk about?" - ); - } -} diff --git a/crates/hermesllm/src/providers/traits.rs b/crates/hermesllm/src/providers/traits.rs index 2689ac8b..5148ed7a 100644 --- a/crates/hermesllm/src/providers/traits.rs +++ b/crates/hermesllm/src/providers/traits.rs @@ -6,59 +6,6 @@ use std::error::Error; use std::fmt; -/// Conversion mode for provider requests/responses -#[derive(Debug, Clone, Copy)] -pub enum ConversionMode { - /// Compatible: Convert between different provider formats to ensure compatibility - Compatible, - /// Passthrough: Pass requests/responses through with minimal modification - Passthrough, -} - -/// Error types for provider operations -#[derive(Debug)] -pub struct ProviderRequestError { - pub message: String, - pub source: Option>, -} - -#[derive(Debug)] -pub struct ProviderResponseError { - pub message: String, - pub source: Option>, -} - -impl fmt::Display for ProviderRequestError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Provider request error: {}", self.message) - } -} - -impl fmt::Display for ProviderResponseError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Provider response error: {}", self.message) - } -} - -impl Error for ProviderRequestError { - fn source(&self) -> Option<&(dyn Error + 'static)> { - self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static)) - } -} - -impl Error for ProviderResponseError { - fn source(&self) -> Option<&(dyn Error + 'static)> { - self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static)) - } -} - -/// Trait for token usage information -pub trait TokenUsage { - fn completion_tokens(&self) -> usize; - fn prompt_tokens(&self) -> usize; - fn total_tokens(&self) -> usize; -} - /// Trait for provider-specific request types pub trait ProviderRequest: Send + Sync { /// Extract the model name from the request @@ -107,13 +54,26 @@ pub trait ProviderStreamResponse: Send + Sync { } /// Trait for streaming response iterators -/// -/// This trait ensures that implementing types are iterators that yield -/// ProviderStreamResponse results. pub trait ProviderStreamResponseIter: Iterator, Box>> + Send + Sync { // No additional methods needed - just the Iterator constraint with proper bounds } +/// Conversion mode for provider requests/responses +#[derive(Debug, Clone, Copy)] +pub enum ConversionMode { + /// Compatible: Convert between different provider formats to ensure compatibility + Compatible, + /// Passthrough: Pass requests/responses through with minimal modification + Passthrough, +} + +/// Trait for token usage information +pub trait TokenUsage { + fn completion_tokens(&self) -> usize; + fn prompt_tokens(&self) -> usize; + fn total_tokens(&self) -> usize; +} + // ============================================================================ // PROVIDER FUNCTIONS - NO TRAITS, JUST PARAMETERIZED CONVERSION // ============================================================================ @@ -152,14 +112,42 @@ pub trait ProviderStreamResponseIter: Iterator Result, ProviderRequestError> { +// ============================================================================ +// PROVIDER ADAPTER REGISTRY (Organizational Enhancement) +// ============================================================================ + +/// Provider adapter configuration +#[derive(Debug, Clone)] +pub struct ProviderConfig { + pub supported_apis: &'static [&'static str], + pub adapter_type: AdapterType, +} + +#[derive(Debug, Clone)] +pub enum AdapterType { + OpenAICompatible, + // Future: Claude, Gemini, etc. +} + +/// Get provider configuration +pub fn get_provider_config(provider_id: &ProviderId) -> ProviderConfig { match provider_id { - // All these providers currently use OpenAI-compatible chat completions API - // In the future, we can add provider-specific handling in separate match arms ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch | ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => { + ProviderConfig { + supported_apis: &["/v1/chat/completions"], + adapter_type: AdapterType::OpenAICompatible, + } + } + } +} +/// Parse request from bytes using provider ID - returns generic ProviderRequest trait object +pub fn try_request_from_bytes(bytes: &[u8], provider_id: &ProviderId) -> Result, ProviderRequestError> { + let config = get_provider_config(provider_id); + + match config.adapter_type { + AdapterType::OpenAICompatible => { let request = crate::apis::openai::ChatCompletionsRequest::try_from((bytes, provider_id)) .map_err(|e| ProviderRequestError { message: format!("Failed to parse request: {}", e), @@ -175,9 +163,10 @@ pub fn try_request_from_bytes(bytes: &[u8], provider_id: &ProviderId) -> Result< /// Parse response from bytes using provider ID - returns generic ProviderResponse trait object pub fn try_response_from_bytes(bytes: &[u8], provider_id: &ProviderId, _mode: ConversionMode) -> Result, ProviderResponseError> { - match provider_id { - ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek - | ProviderId::Arch | ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => { + let config = get_provider_config(provider_id); + + match config.adapter_type { + AdapterType::OpenAICompatible => { // Parameterized conversion allows provider-specific response parsing let response = crate::apis::openai::ChatCompletionsResponse::try_from((bytes, provider_id)) .map_err(|e| ProviderResponseError { @@ -192,13 +181,11 @@ pub fn try_response_from_bytes(bytes: &[u8], provider_id: &ProviderId, _mode: Co } /// Create streaming response using provider ID - returns clean ProviderStreamResponseIter trait object -/// -/// This function returns a ProviderStreamResponseIter that's just an iterator, -/// eliminating the complex nested Result>> type completely. pub fn try_streaming_from_bytes(bytes: &[u8], provider_id: &ProviderId, _mode: ConversionMode) -> Result, Box> { - match provider_id { - ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek - | ProviderId::Arch | ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => { + let config = get_provider_config(provider_id); + + match config.adapter_type { + AdapterType::OpenAICompatible => { // Parse SSE (Server-Sent Events) streaming data let s = std::str::from_utf8(bytes)?; let lines: Vec = s.lines().map(|line| line.to_string()).collect(); @@ -211,29 +198,50 @@ pub fn try_streaming_from_bytes(bytes: &[u8], provider_id: &ProviderId, _mode: C } /// Check if provider has compatible API -/// -/// Replaces the old ProviderInterface::has_compatible_api method. -/// This function enables runtime API compatibility checking without needing a provider instance. pub fn has_compatible_api(provider_id: &ProviderId, api_path: &str) -> bool { - match provider_id { - // Currently all these providers support OpenAI chat completions API - // Future providers with different APIs will get their own match arms - ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek - | ProviderId::Arch | ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => { - api_path == "/v1/chat/completions" - } - } + let config = get_provider_config(provider_id); + config.supported_apis.iter().any(|&supported| supported == api_path) } /// Get supported APIs for provider -/// -/// Replaces the old ProviderInterface::supported_apis method. -/// Returns a static list of supported API endpoints for the given provider. pub fn supported_apis(provider_id: &ProviderId) -> Vec<&'static str> { - match provider_id { - ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek - | ProviderId::Arch | ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => { - vec!["/v1/chat/completions"] - } + let config = get_provider_config(provider_id); + config.supported_apis.to_vec() +} + +/// Error types for provider operations +#[derive(Debug)] +pub struct ProviderRequestError { + pub message: String, + pub source: Option>, +} + +#[derive(Debug)] +pub struct ProviderResponseError { + pub message: String, + pub source: Option>, +} + +impl fmt::Display for ProviderRequestError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Provider request error: {}", self.message) + } +} + +impl fmt::Display for ProviderResponseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Provider response error: {}", self.message) + } +} + +impl Error for ProviderRequestError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static)) + } +} + +impl Error for ProviderResponseError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static)) } } diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 5552cc54..0d981297 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -352,17 +352,16 @@ impl HttpContext for StreamContext { }; // Set the resolved model using the trait method - deserialized_body.set_model(resolved_model); + deserialized_body.set_model(resolved_model.clone()); // Extract user message for tracing self.user_message = deserialized_body.extract_user_message(); info!( - "on_http_request_body: provider: {}, model requested (in body): {}, model selected: {}, final model: {}", + "on_http_request_body: provider: {}, model requested (in body): {}, model selected: {}", self.llm_provider().name, model_requested, model_name.unwrap_or(&"None".to_string()), - deserialized_body.model(), ); // Use provider interface for streaming detection and setup @@ -376,7 +375,7 @@ impl HttpContext for StreamContext { // Use provider interface for text extraction (after potential mutation) let input_tokens_str = deserialized_body.extract_messages_text(); // enforce ratelimits on ingress - if let Err(e) = self.enforce_ratelimits(&model_requested, input_tokens_str.as_str()) { + if let Err(e) = self.enforce_ratelimits(&resolved_model, input_tokens_str.as_str()) { self.send_server_error( ServerError::ExceededRatelimit(e), Some(StatusCode::TOO_MANY_REQUESTS), diff --git a/crates/llm_gateway/tests/integration.rs b/crates/llm_gateway/tests/integration.rs index 108ab1ce..82ae8322 100644 --- a/crates/llm_gateway/tests/integration.rs +++ b/crates/llm_gateway/tests/integration.rs @@ -12,7 +12,7 @@ fn wasm_module() -> String { wasm_file.exists(), "Run `cargo build --release --target=wasm32-wasip1` first" ); - wasm_file.to_str().unwrap().to_string() + wasm_file.to_string_lossy().to_string() } fn request_headers_expectations(module: &mut Tester, http_context: i32) { @@ -267,17 +267,12 @@ fn llm_gateway_bad_request_to_open_ai_chat_completions() { .returning(Some(incomplete_chat_completions_request_body)) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Info), Some("on_http_request_body: provider: open-ai-gpt-4, model requested (in body): gpt-1, model selected: gpt-4")) - .expect_send_local_response( - Some(StatusCode::BAD_REQUEST.as_u16().into()), - None, - None, - None, - ) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), Some("getting token count model=gpt-4")) + .expect_log(Some(LogLevel::Debug), Some("Recorded input token count: 13")) .expect_metric_record("input_sequence_length", 13) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), Some("Applying ratelimit for model: gpt-4")) + .expect_log(Some(LogLevel::Debug), Some(r#"Checking limit for provider=gpt-4, with selector=Header { key: "selector-key", value: "selector-value" }, consuming tokens=13"#)) + .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None) .execute_and_expect(ReturnType::Action(Action::Continue)) .unwrap(); } @@ -386,11 +381,11 @@ fn llm_gateway_request_not_ratelimited() { .returning(Some(chat_completions_request_body)) // The actual call is not important in this test, we just need to grab the token_id .expect_log(Some(LogLevel::Info), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), Some("getting token count model=gpt-4")) + .expect_log(Some(LogLevel::Debug), Some("Recorded input token count: 29")) .expect_metric_record("input_sequence_length", 29) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), Some("Applying ratelimit for model: gpt-4")) + .expect_log(Some(LogLevel::Debug), Some(r#"Checking limit for provider=gpt-4, with selector=Header { key: "selector-key", value: "selector-value" }, consuming tokens=29"#)) .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None) .execute_and_expect(ReturnType::Action(Action::Continue)) .unwrap(); @@ -433,11 +428,11 @@ fn llm_gateway_override_model_name() { // The actual call is not important in this test, we just need to grab the token_id .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Info), Some("on_http_request_body: provider: open-ai-gpt-4, model requested (in body): gpt-1, model selected: gpt-4")) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), Some("getting token count model=gpt-4")) + .expect_log(Some(LogLevel::Debug), Some("Recorded input token count: 29")) .expect_metric_record("input_sequence_length", 29) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), Some("Applying ratelimit for model: gpt-4")) + .expect_log(Some(LogLevel::Debug), Some(r#"Checking limit for provider=gpt-4, with selector=Header { key: "selector-key", value: "selector-value" }, consuming tokens=29"#)) .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None) .execute_and_expect(ReturnType::Action(Action::Continue)) .unwrap(); @@ -483,8 +478,8 @@ fn llm_gateway_override_use_default_model() { Some(LogLevel::Info), Some("on_http_request_body: provider: open-ai-gpt-4, model requested (in body): gpt-1, model selected: gpt-4"), ) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), Some("getting token count model=gpt-4")) + .expect_log(Some(LogLevel::Debug), Some("Recorded input token count: 29")) .expect_metric_record("input_sequence_length", 29) .expect_log(Some(LogLevel::Debug), Some("Applying ratelimit for model: gpt-4")) .expect_log(Some(LogLevel::Debug), Some(r#"Checking limit for provider=gpt-4, with selector=Header { key: "selector-key", value: "selector-value" }, consuming tokens=29"#)) @@ -530,11 +525,11 @@ fn llm_gateway_override_use_model_name_none() { // The actual call is not important in this test, we just need to grab the token_id .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Info), Some("on_http_request_body: provider: open-ai-gpt-4, model requested (in body): none, model selected: gpt-4")) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), Some("getting token count model=gpt-4")) + .expect_log(Some(LogLevel::Debug), Some("Recorded input token count: 29")) .expect_metric_record("input_sequence_length", 29) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), Some("Applying ratelimit for model: gpt-4")) + .expect_log(Some(LogLevel::Debug), Some(r#"Checking limit for provider=gpt-4, with selector=Header { key: "selector-key", value: "selector-value" }, consuming tokens=29"#)) .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None) .execute_and_expect(ReturnType::Action(Action::Continue)) .unwrap();