diff --git a/crates/brightstaff/src/handlers/chat_completions.rs b/crates/brightstaff/src/handlers/chat_completions.rs index 37da961f..d0e5910a 100644 --- a/crates/brightstaff/src/handlers/chat_completions.rs +++ b/crates/brightstaff/src/handlers/chat_completions.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use bytes::Bytes; use common::configuration::ModelUsagePreference; use common::consts::ARCH_PROVIDER_HINT_HEADER; -use hermesllm::providers::openai::types::ChatCompletionsRequest; +use hermesllm::apis::openai::ChatCompletionsRequest; use http_body_util::combinators::BoxBody; use http_body_util::{BodyExt, Full, StreamBody}; use hyper::body::Frame; @@ -93,7 +93,7 @@ pub async fn chat_completions( chat_completion_request.metadata.and_then(|metadata| { metadata .get("archgw_preference_config") - .and_then(|value| value.as_str().map(String::from)) + .map(|value| value.to_string()) }); let usage_preferences: Option> = usage_preferences_str @@ -105,9 +105,7 @@ pub async fn chat_completions( .messages .last() .map_or("None".to_string(), |msg| { - msg.content.as_ref().map_or("None".to_string(), |content| { - content.to_string().replace('\n', "\\n") - }) + msg.content.to_string().replace('\n', "\\n") }); const MAX_MESSAGE_LENGTH: usize = 50; diff --git a/crates/brightstaff/src/handlers/models.rs b/crates/brightstaff/src/handlers/models.rs index 3a4662a6..ac1bbebe 100644 --- a/crates/brightstaff/src/handlers/models.rs +++ b/crates/brightstaff/src/handlers/models.rs @@ -1,6 +1,6 @@ use bytes::Bytes; use common::configuration::{IntoModels, LlmProvider}; -use hermesllm::providers::openai::types::Models; +use hermesllm::apis::openai::Models; use http_body_util::{combinators::BoxBody, BodyExt, Full}; use hyper::{Response, StatusCode}; use serde_json; diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index b5bf0204..34fa3aa3 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -98,7 +98,7 @@ async fn main() -> Result<(), Box> { let peer_addr = stream.peer_addr()?; let io = TokioIo::new(stream); - let router_service = Arc::clone(&router_service); + let router_service: Arc = Arc::clone(&router_service); let llm_provider_endpoint = llm_provider_endpoint.clone(); let llm_providers = llm_providers.clone(); diff --git a/crates/brightstaff/src/router/llm_router.rs b/crates/brightstaff/src/router/llm_router.rs index fc6d9365..3b09c115 100644 --- a/crates/brightstaff/src/router/llm_router.rs +++ b/crates/brightstaff/src/router/llm_router.rs @@ -4,7 +4,7 @@ use common::{ configuration::{LlmProvider, ModelUsagePreference, RoutingPreference}, consts::ARCH_PROVIDER_HINT_HEADER, }; -use hermesllm::providers::openai::types::{ChatCompletionsResponse, ContentType, Message}; +use hermesllm::apis::openai::{ChatCompletionsResponse, Message}; use hyper::header; use thiserror::Error; use tracing::{debug, info, warn}; @@ -153,9 +153,7 @@ impl RouterService { return Ok(None); } - if let Some(ContentType::Text(content)) = - &chat_completion_response.choices[0].message.content - { + if let Some(content) = &chat_completion_response.choices[0].message.content { let parsed_response = self .router_model .parse_response(content, &usage_preferences)?; diff --git a/crates/brightstaff/src/router/router_model.rs b/crates/brightstaff/src/router/router_model.rs index ec0c1a1f..372907af 100644 --- a/crates/brightstaff/src/router/router_model.rs +++ b/crates/brightstaff/src/router/router_model.rs @@ -1,5 +1,5 @@ use common::configuration::ModelUsagePreference; -use hermesllm::providers::openai::types::{ChatCompletionsRequest, Message}; +use hermesllm::apis::openai::{ChatCompletionsRequest, Message}; use thiserror::Error; #[derive(Debug, Error)] diff --git a/crates/brightstaff/src/router/router_model_v1.rs b/crates/brightstaff/src/router/router_model_v1.rs index bd06b525..1c1c14ef 100644 --- a/crates/brightstaff/src/router/router_model_v1.rs +++ b/crates/brightstaff/src/router/router_model_v1.rs @@ -2,9 +2,8 @@ use std::collections::HashMap; use common::{ configuration::{ModelUsagePreference, RoutingPreference}, - consts::{SYSTEM_ROLE, TOOL_ROLE, USER_ROLE}, }; -use hermesllm::providers::openai::types::{ChatCompletionsRequest, ContentType, Message}; +use hermesllm::apis::openai::{ChatCompletionsRequest, MessageContent, Message, Role}; use serde::{Deserialize, Serialize}; use tracing::{debug, warn}; @@ -80,7 +79,9 @@ impl RouterModel for RouterModelV1 { // when role == tool its tool call response let messages_vec = messages .iter() - .filter(|m| m.role != SYSTEM_ROLE && m.role != TOOL_ROLE && m.content.is_some()) + .filter(|m| { + m.role != Role::System && m.role != Role::Tool && !m.content.to_string().is_empty() + }) .collect::>(); // Following code is to ensure that the conversation does not exceed max token length @@ -88,13 +89,7 @@ impl RouterModel for RouterModelV1 { let mut token_count = ARCH_ROUTER_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR; let mut selected_messages_list_reversed: Vec<&Message> = vec![]; for (selected_messsage_count, message) in messages_vec.iter().rev().enumerate() { - let message_token_count = message - .content - .as_ref() - .unwrap_or(&ContentType::Text("".to_string())) - .to_string() - .len() - / TOKEN_LENGTH_DIVISOR; + let message_token_count = message.content.to_string().len() / TOKEN_LENGTH_DIVISOR; token_count += message_token_count; if token_count > self.max_token_length { debug!( @@ -104,7 +99,7 @@ impl RouterModel for RouterModelV1 { , selected_messsage_count, messages_vec.len() ); - if message.role == USER_ROLE { + if message.role == Role::User { // If message that exceeds max token length is from user, we need to keep it selected_messages_list_reversed.push(message); } @@ -125,12 +120,12 @@ impl RouterModel for RouterModelV1 { // ensure that first and last selected message is from user if let Some(first_message) = selected_messages_list_reversed.first() { - if first_message.role != USER_ROLE { + if first_message.role != Role::User { warn!("RouterModelV1: last message in the conversation is not from user, this may lead to incorrect routing"); } } if let Some(last_message) = selected_messages_list_reversed.last() { - if last_message.role != USER_ROLE { + if last_message.role != Role::User { warn!("RouterModelV1: first message in the conversation is not from user, this may lead to incorrect routing"); } } @@ -143,9 +138,10 @@ impl RouterModel for RouterModelV1 { Message { role: message.role.clone(), // we can unwrap here because we have already filtered out messages without content - content: Some(ContentType::Text( - message.content.as_ref().unwrap().to_string(), - )), + content: MessageContent::Text(message.content.to_string()), + name: None, + tool_calls: None, + tool_call_id: None, } }) .collect::>(); @@ -160,8 +156,11 @@ impl RouterModel for RouterModelV1 { ChatCompletionsRequest { model: self.routing_model.clone(), messages: vec![Message { - content: Some(ContentType::Text(router_message)), - role: USER_ROLE.to_string(), + content: MessageContent::Text(router_message), + role: Role::User, + name: None, + tool_calls: None, + tool_call_id: None, }], temperature: Some(0.01), ..Default::default() @@ -347,9 +346,9 @@ Based on your analysis, provide your response in the following JSON formats if y let req = router.generate_request(&conversation, &None); - let prompt = req.messages[0].content.as_ref().unwrap(); + let prompt = req.messages[0].content.to_string(); - assert_eq!(expected_prompt, prompt.to_string()); + assert_eq!(expected_prompt, prompt); } #[test] @@ -412,9 +411,9 @@ Based on your analysis, provide your response in the following JSON formats if y }]); let req = router.generate_request(&conversation, &usage_preferences); - let prompt = req.messages[0].content.as_ref().unwrap(); + let prompt = req.messages[0].content.to_string(); - assert_eq!(expected_prompt, prompt.to_string()); + assert_eq!(expected_prompt, prompt); } #[test] @@ -472,9 +471,9 @@ Based on your analysis, provide your response in the following JSON formats if y let req = router.generate_request(&conversation, &None); - let prompt = req.messages[0].content.as_ref().unwrap(); + let prompt = req.messages[0].content.to_string(); - assert_eq!(expected_prompt, prompt.to_string()); + assert_eq!(expected_prompt, prompt); } #[test] @@ -533,9 +532,9 @@ Based on your analysis, provide your response in the following JSON formats if y let req = router.generate_request(&conversation, &None); - let prompt = req.messages[0].content.as_ref().unwrap(); + let prompt = req.messages[0].content.to_string(); - assert_eq!(expected_prompt, prompt.to_string()); + assert_eq!(expected_prompt, prompt); } #[test] @@ -601,9 +600,9 @@ Based on your analysis, provide your response in the following JSON formats if y let req = router.generate_request(&conversation, &None); - let prompt = req.messages[0].content.as_ref().unwrap(); + let prompt = req.messages[0].content.to_string(); - assert_eq!(expected_prompt, prompt.to_string()); + assert_eq!(expected_prompt, prompt); } #[test] @@ -670,9 +669,9 @@ Based on your analysis, provide your response in the following JSON formats if y let req = router.generate_request(&conversation, &None); - let prompt = req.messages[0].content.as_ref().unwrap(); + let prompt = req.messages[0].content.to_string(); - assert_eq!(expected_prompt, prompt.to_string()); + assert_eq!(expected_prompt, prompt); } #[test] @@ -716,14 +715,14 @@ Based on your analysis, provide your response in the following JSON formats if y }, { "role": "assistant", - "content": null, + "content": "", "tool_calls": [ { "id": "toolcall-abc123", "type": "function", "function": { "name": "get_weather", - "arguments": { "location": "Tokyo" } + "arguments": "{ \"location\": \"Tokyo\" }" } } ] @@ -763,11 +762,11 @@ Based on your analysis, provide your response in the following JSON formats if y let conversation: Vec = serde_json::from_str(conversation_str).unwrap(); - let req = router.generate_request(&conversation, &None); + let req: ChatCompletionsRequest = router.generate_request(&conversation, &None); - let prompt = req.messages[0].content.as_ref().unwrap(); + let prompt = req.messages[0].content.to_string(); - assert_eq!(expected_prompt, prompt.to_string()); + assert_eq!(expected_prompt, prompt); } #[test] diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index 186691dc..20d2623b 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -1,4 +1,4 @@ -use hermesllm::providers::openai::types::{ModelDetail, ModelObject, Models}; +use hermesllm::apis::openai::{ModelDetail, ModelObject, Models}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::fmt::Display; @@ -177,6 +177,14 @@ impl Display for LlmProviderType { } } +impl LlmProviderType { + /// Get the ProviderId for this LlmProviderType + /// Used with the new function-based hermesllm API + pub fn to_provider_id(&self) -> hermesllm::ProviderId { + hermesllm::ProviderId::from(self.to_string().as_str()) + } +} + #[derive(Serialize, Deserialize, Debug)] pub struct ModelUsagePreference { pub model: String, @@ -252,6 +260,14 @@ impl Display for LlmProvider { } } +impl LlmProvider { + /// Get the ProviderId for this LlmProvider + /// Used with the new function-based hermesllm API + pub fn to_provider_id(&self) -> hermesllm::ProviderId { + self.provider_interface.to_provider_id() + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Endpoint { pub endpoint: Option, diff --git a/crates/common/src/errors.rs b/crates/common/src/errors.rs index 582c0a7c..21af3c94 100644 --- a/crates/common/src/errors.rs +++ b/crates/common/src/errors.rs @@ -1,7 +1,7 @@ use proxy_wasm::types::Status; use crate::{api::open_ai::ChatCompletionChunkResponseError, ratelimit}; -use hermesllm::providers::openai::types::OpenAIError; +use hermesllm::apis::openai::OpenAIError; #[derive(thiserror::Error, Debug)] pub enum ClientError { diff --git a/crates/hermesllm/README.md b/crates/hermesllm/README.md index 807e63f6..6f6b99e1 100644 --- a/crates/hermesllm/README.md +++ b/crates/hermesllm/README.md @@ -1,63 +1,145 @@ # hermesllm -A Rust library for translating LLM (Large Language Model) API requests and responses between Mistral, Groq, Gemini, Deepseek, OpenAI, and other provider-compliant formats. +A Rust library for handling LLM (Large Language Model) API requests and responses with unified abstractions across multiple providers. ## Features -- Unified types for chat completions and model metadata across multiple LLM providers -- Builder-pattern API for constructing requests in an idiomatic Rust style -- Easy conversion between provider formats -- Streaming and non-streaming response support +- Unified request/response types with provider-specific parsing +- Support for both streaming and non-streaming responses +- Type-safe provider identification +- OpenAI-compatible API structure with extensible provider support ## Supported Providers -- Mistral -- Deepseek -- Groq -- Gemini - OpenAI +- Mistral +- Groq +- Deepseek +- Gemini - Claude -- Github +- GitHub ## Installation -Add the following to your `Cargo.toml`: +Add to your `Cargo.toml`: ```toml [dependencies] -hermesllm = { git = "https://github.com/katanemo/archgw", subdir = "crates/hermesllm" } +hermesllm = { path = "../hermesllm" } # or appropriate path in workspace ``` -_Replace the path with the appropriate location if using as a workspace member or published crate._ - ## Usage -Construct a chat completion request using the builder pattern: +### Basic Request Parsing ```rust -use hermesllm::Provider; -use hermesllm::providers::openai::types::ChatCompletionsRequest; +use hermesllm::providers::{ProviderRequestType, ProviderRequest, ProviderId}; -let request = ChatCompletionsRequest::builder("gpt-3.5-turbo", vec![Message::new("Hi".to_string())]) - .build() - .expect("Failed to build OpenAIRequest"); +// Parse request from JSON bytes +let request_bytes = r#"{"model": "gpt-4", "messages": [{"role": "user", "content": "Hello!"}]}"#; -// Convert to bytes for a specific provider -let bytes = request.to_bytes(Provider::OpenAI)?; +// Parse with provider context +let request = ProviderRequestType::try_from((request_bytes.as_bytes(), &ProviderId::OpenAI))?; + +// Access request properties +println!("Model: {}", request.model()); +println!("User message: {:?}", request.get_recent_user_message()); +println!("Is streaming: {}", request.is_streaming()); ``` -## API Overview +### Working with Responses -- `Provider`: Enum listing all supported LLM providers. -- `ChatCompletionsRequest`: Builder-pattern struct for creating chat completion requests. -- `ChatCompletionsResponse`: Struct for parsing responses. -- Streaming support via `SseChatCompletionIter`. -- Error handling via `OpenAIError`. +```rust +use hermesllm::providers::{ProviderResponseType, ProviderResponse}; -## Contributing +// Parse response from provider +let response_bytes = /* JSON response from LLM */; +let response = ProviderResponseType::try_from((response_bytes, ProviderId::OpenAI))?; -Contributions are welcome! Please open issues or pull requests for bug fixes, new features, or provider integrations. +// Extract token usage +if let Some((prompt, completion, total)) = response.extract_usage_counts() { + println!("Tokens used: {}/{}/{}", prompt, completion, total); +} +``` + +### Handling Streaming Responses + +```rust +use hermesllm::providers::{ProviderStreamResponseIter, ProviderStreamResponse}; + +// Create streaming iterator from SSE data +let sse_data = /* Server-Sent Events data */; +let mut stream = ProviderStreamResponseIter::try_from((sse_data, &ProviderId::OpenAI))?; + +// Process streaming chunks +for chunk_result in stream { + match chunk_result { + Ok(chunk) => { + if let Some(content) = chunk.content_delta() { + print!("{}", content); + } + if chunk.is_final() { + break; + } + } + Err(e) => eprintln!("Stream error: {}", e), + } +} +``` + +### Provider Compatibility + +```rust +use hermesllm::providers::{ProviderId, has_compatible_api, supported_apis}; + +// Check API compatibility +let provider = ProviderId::Groq; +if has_compatible_api(&provider, "/v1/chat/completions") { + println!("Provider supports chat completions"); +} + +// List supported APIs +let apis = supported_apis(&provider); +println!("Supported APIs: {:?}", apis); +``` + +## Core Types + +### Provider Types +- `ProviderId` - Enum identifying supported providers (OpenAI, Mistral, Groq, etc.) +- `ProviderRequestType` - Enum wrapping provider-specific request types +- `ProviderResponseType` - Enum wrapping provider-specific response types +- `ProviderStreamResponseIter` - Iterator for streaming response chunks + +### Traits +- `ProviderRequest` - Common interface for all request types +- `ProviderResponse` - Common interface for all response types +- `ProviderStreamResponse` - Interface for streaming response chunks +- `TokenUsage` - Interface for token usage information + +### OpenAI API Types +- `ChatCompletionsRequest` - Chat completion request structure +- `ChatCompletionsResponse` - Chat completion response structure +- `Message`, `Role`, `MessageContent` - Message building blocks + +## Architecture + +The library uses a type-safe enum-based approach that: + +- **Provides Type Safety**: All provider operations are checked at compile time +- **Enables Runtime Provider Selection**: Provider can be determined from request headers or config +- **Maintains Clean Abstractions**: Common traits hide provider-specific details +- **Supports Extensibility**: New providers can be added by extending the enums + +All requests are parsed into a common `ProviderRequestType` enum which implements the `ProviderRequest` trait, allowing uniform access to request properties regardless of the underlying provider format. + +## Examples + +See the `src/lib.rs` tests for complete working examples of: +- Parsing requests with provider context +- Handling streaming responses +- Working with token usage information ## License -This project is licensed under the terms of the [MIT License](../LICENSE). +This project is licensed under the MIT License. diff --git a/crates/hermesllm/src/apis/openai.rs b/crates/hermesllm/src/apis/openai.rs index 7f75c6be..2471fc35 100644 --- a/crates/hermesllm/src/apis/openai.rs +++ b/crates/hermesllm/src/apis/openai.rs @@ -2,7 +2,13 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; use serde_with::skip_serializing_none; use std::collections::HashMap; +use std::fmt::Display; +use thiserror::Error; + + +use crate::providers::request::{ProviderRequest, ProviderRequestError}; +use crate::providers::response::{ProviderResponse, ProviderStreamResponse, TokenUsage, SseStreamIter}; use super::ApiDefinition; // ============================================================================ @@ -115,8 +121,8 @@ pub enum Role { #[skip_serializing_none] #[derive(Serialize, Deserialize, Debug, Clone)] pub struct Message { - pub content: MessageContent, pub role: Role, + pub content: MessageContent, pub name: Option, /// Tool calls made by the assistant (only present for assistant role) pub tool_calls: Option>, @@ -124,8 +130,6 @@ pub struct Message { pub tool_call_id: Option, } - - #[skip_serializing_none] #[derive(Serialize, Deserialize, Debug, Clone)] pub struct ResponseMessage { @@ -170,6 +174,28 @@ pub enum MessageContent { Parts(Vec), } +impl Display for MessageContent { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + MessageContent::Text(text) => write!(f, "{}", text), + MessageContent::Parts(parts) => { + let text_parts: Vec = parts + .iter() + .filter_map(|part| match part { + ContentPart::Text { text } => Some(text.clone()), + ContentPart::ImageUrl { .. } => { + // skip image URLs or their data in text representation + None + } + }) + .collect(); + let combined_text = text_parts.join("\n"); + write!(f, "{}", combined_text) + } + } + } +} + /// Individual content part within a message (text or image) #[derive(Serialize, Deserialize, Debug, Clone)] #[serde(tag = "type")] @@ -424,6 +450,239 @@ pub struct StreamOptions { pub include_usage: Option, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelDetail { + pub id: String, + pub object: String, + pub created: usize, + pub owned_by: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum ModelObject { + #[serde(rename = "list")] + List, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Models { + pub object: ModelObject, + pub data: Vec, +} + + +// Error type for streaming operations +#[derive(Debug, thiserror::Error)] +pub enum OpenAIStreamError { + #[error("JSON parsing error: {0}")] + JsonError(#[from] serde_json::Error), + #[error("UTF-8 parsing error: {0}")] + Utf8Error(#[from] std::str::Utf8Error), + #[error("Invalid streaming data: {0}")] + InvalidStreamingData(String), +} + +#[derive(Debug, Error)] +pub enum OpenAIError { + #[error("json error: {0}")] + JsonParseError(#[from] serde_json::Error), + #[error("utf8 parsing error: {0}")] + Utf8Error(#[from] std::str::Utf8Error), + #[error("invalid streaming data err {source}, data: {data}")] + InvalidStreamingData { + source: serde_json::Error, + data: String, + }, + #[error("unsupported provider: {provider}")] + UnsupportedProvider { provider: String }, +} + +// ============================================================================ +/// Trait Implementations +/// =========================================================================== + + +/// Parameterized conversion for ChatCompletionsRequest +impl TryFrom<&[u8]> for ChatCompletionsRequest { + type Error = OpenAIStreamError; + + fn try_from(bytes: &[u8]) -> Result { + serde_json::from_slice(bytes).map_err(OpenAIStreamError::from) + } +} + +/// Parameterized conversion for ChatCompletionsResponse +impl TryFrom<&[u8]> for ChatCompletionsResponse { + type Error = OpenAIStreamError; + + fn try_from(bytes: &[u8]) -> Result { + serde_json::from_slice(bytes).map_err(OpenAIStreamError::from) + } +} + +/// Implementation of TokenUsage for OpenAI Usage type +impl TokenUsage for Usage { + fn completion_tokens(&self) -> usize { + self.completion_tokens as usize + } + + fn prompt_tokens(&self) -> usize { + self.prompt_tokens as usize + } + + fn total_tokens(&self) -> usize { + self.total_tokens as usize + } +} + +/// Implementation of ProviderRequest for ChatCompletionsRequest +impl ProviderRequest for ChatCompletionsRequest { + fn model(&self) -> &str { + &self.model + } + + fn set_model(&mut self, model: String) { + self.model = model; + } + + fn is_streaming(&self) -> bool { + self.stream.unwrap_or_default() + } + + fn extract_messages_text(&self) -> String { + self.messages.iter().fold(String::new(), |acc, m| { + acc + " " + &match &m.content { + MessageContent::Text(text) => text.clone(), + MessageContent::Parts(parts) => parts.iter().map(|part| match part { + ContentPart::Text { text } => text.clone(), + ContentPart::ImageUrl { .. } => "[Image]".to_string(), + }).collect::>().join(" ") + } + }) + } + + fn get_recent_user_message(&self) -> Option { + self.messages.last().and_then(|msg| { + match &msg.content { + MessageContent::Text(text) => Some(text.clone()), + MessageContent::Parts(_) => None, // No user message in parts + } + }) + } + + fn to_bytes(&self) -> Result, ProviderRequestError> { + serde_json::to_vec(&self).map_err(|e| ProviderRequestError { + message: format!("Failed to serialize OpenAI request: {}", e), + source: Some(Box::new(e)), + }) + } +} + +/// Implementation of ProviderResponse for ChatCompletionsResponse +impl ProviderResponse for ChatCompletionsResponse { + fn usage(&self) -> Option<&dyn TokenUsage> { + Some(&self.usage) + } + + fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> { + Some(( + self.usage.prompt_tokens(), + self.usage.completion_tokens(), + self.usage.total_tokens(), + )) + } +} + +// ============================================================================ +// OPENAI SSE STREAMING ITERATOR +// ============================================================================ + +/// OpenAI-specific SSE streaming iterator +/// Handles OpenAI's specific SSE format and ChatCompletionsStreamResponse parsing +pub struct OpenAISseIter +where + I: Iterator, + I::Item: AsRef, +{ + sse_stream: SseStreamIter, +} + +impl OpenAISseIter +where + I: Iterator, + I::Item: AsRef, +{ + pub fn new(sse_stream: SseStreamIter) -> Self { + Self { sse_stream } + } +} + +impl Iterator for OpenAISseIter +where + I: Iterator, + I::Item: AsRef, +{ + type Item = Result, Box>; + + fn next(&mut self) -> Option { + for line in &mut self.sse_stream.lines { + let line = line.as_ref(); + if line.is_empty() { + continue; + } + + if line.starts_with("data: ") { + let data = &line[6..]; // Remove "data: " prefix + if data == "[DONE]" { + return None; + } + + // Skip ping messages (usually from other providers, but handle gracefully) + if data == r#"{"type": "ping"}"# { + continue; + } + + // OpenAI-specific parsing of ChatCompletionsStreamResponse + match serde_json::from_str::(data) { + Ok(response) => return Some(Ok(Box::new(response))), + Err(e) => return Some(Err(Box::new( + OpenAIStreamError::InvalidStreamingData(format!("Error parsing OpenAI streaming data: {}, data: {}", e, data)) + ))), + } + } + } + None + } +} + +// Direct implementation of ProviderStreamResponse trait on ChatCompletionsStreamResponse +impl ProviderStreamResponse for ChatCompletionsStreamResponse { + fn content_delta(&self) -> Option<&str> { + self.choices + .first() + .and_then(|choice| choice.delta.content.as_deref()) + } + + fn is_final(&self) -> bool { + self.choices + .first() + .map(|choice| choice.finish_reason.is_some()) + .unwrap_or(false) + } + + fn role(&self) -> Option<&str> { + self.choices + .first() + .and_then(|choice| choice.delta.role.as_ref().map(|r| match r { + Role::System => "system", + Role::User => "user", + Role::Assistant => "assistant", + Role::Tool => "tool", + })) + } +} + + #[cfg(test)] mod tests { use super::*; diff --git a/crates/hermesllm/src/clients/transformer.rs b/crates/hermesllm/src/clients/transformer.rs index c6d524f4..23ca26ee 100644 --- a/crates/hermesllm/src/clients/transformer.rs +++ b/crates/hermesllm/src/clients/transformer.rs @@ -13,14 +13,14 @@ //! //! ```rust //! use hermesllm::apis::{ -//! AnthropicMessagesRequest, ChatCompletionsRequest, MessagesRole, MessagesMessage, +//! MessagesRequest, ChatCompletionsRequest, MessagesRole, MessagesMessage, //! MessagesMessageContent, MessagesSystemPrompt, //! }; //! use hermesllm::clients::TransformError; //! use std::convert::TryInto; //! //! // Transform Anthropic to OpenAI -//! let anthropic_req = AnthropicMessagesRequest { +//! let anthropic_req = MessagesRequest { //! model: "claude-3-sonnet".to_string(), //! system: None, //! messages: vec![], diff --git a/crates/hermesllm/src/lib.rs b/crates/hermesllm/src/lib.rs index 169467a1..b4ad9932 100644 --- a/crates/hermesllm/src/lib.rs +++ b/crates/hermesllm/src/lib.rs @@ -5,77 +5,90 @@ pub mod providers; pub mod apis; pub mod clients; - -use std::fmt::Display; -pub enum Provider { - Arch, - Mistral, - Deepseek, - Groq, - Gemini, - OpenAI, - Claude, - Github, -} - -impl From<&str> for Provider { - fn from(value: &str) -> Self { - match value.to_lowercase().as_str() { - "arch" => Provider::Arch, - "mistral" => Provider::Mistral, - "deepseek" => Provider::Deepseek, - "groq" => Provider::Groq, - "gemini" => Provider::Gemini, - "openai" => Provider::OpenAI, - "claude" => Provider::Claude, - "github" => Provider::Github, - _ => panic!("Unknown provider: {}", value), - } - } -} - -impl Display for Provider { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Provider::Arch => write!(f, "Arch"), - Provider::Mistral => write!(f, "Mistral"), - Provider::Deepseek => write!(f, "Deepseek"), - Provider::Groq => write!(f, "Groq"), - Provider::Gemini => write!(f, "Gemini"), - Provider::OpenAI => write!(f, "OpenAI"), - Provider::Claude => write!(f, "Claude"), - Provider::Github => write!(f, "Github"), - } - } -} +// Re-export important types and traits +pub use providers::request::{ProviderRequestType, ProviderRequest, ProviderRequestError}; +pub use providers::response::{ProviderResponseType, ProviderResponse, ProviderStreamResponse, ProviderStreamResponseIter, ProviderResponseError, TokenUsage}; +pub use providers::id::ProviderId; +pub use providers::adapters::{has_compatible_api, supported_apis}; #[cfg(test)] mod tests { - use crate::providers::openai::types::{ChatCompletionsRequest, Message}; + use super::*; #[test] - fn openai_builder() { - let request = - ChatCompletionsRequest::builder("gpt-3.5-turbo", vec![Message::new("Hi".to_string())]) - .temperature(0.7) - .top_p(0.9) - .n(1) - .max_tokens(100) - .stream(false) - .stop(vec!["\n".to_string()]) - .presence_penalty(0.0) - .frequency_penalty(0.0) - .build() - .expect("Failed to build OpenAIRequest"); + fn test_provider_id_conversion() { + assert_eq!(ProviderId::from("openai"), ProviderId::OpenAI); + assert_eq!(ProviderId::from("mistral"), ProviderId::Mistral); + assert_eq!(ProviderId::from("groq"), ProviderId::Groq); + assert_eq!(ProviderId::from("arch"), ProviderId::Arch); + } - assert_eq!(request.model, "gpt-3.5-turbo"); - assert_eq!(request.temperature, Some(0.7)); - assert_eq!(request.top_p, Some(0.9)); - assert_eq!(request.n, Some(1)); - assert_eq!(request.max_tokens, Some(100)); - assert_eq!(request.stream, Some(false)); - assert_eq!(request.stop, Some(vec!["\n".to_string()])); - assert_eq!(request.presence_penalty, Some(0.0)); - assert_eq!(request.frequency_penalty, Some(0.0)); + #[test] + fn test_provider_api_compatibility() { + assert!(has_compatible_api(&ProviderId::OpenAI, "/v1/chat/completions")); + assert!(!has_compatible_api(&ProviderId::OpenAI, "/v1/embeddings")); + } + + #[test] + fn test_provider_supported_apis() { + let apis = supported_apis(&ProviderId::OpenAI); + assert!(apis.contains(&"/v1/chat/completions")); + + // Test that provider supports the expected API endpoints + assert!(has_compatible_api(&ProviderId::OpenAI, "/v1/chat/completions")); + } + + #[test] + fn test_provider_request_parsing() { + // Test with a sample JSON request + let json_request = r#"{ + "model": "gpt-4", + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": "Hello!" + } + ] + }"#; + + let result: Result = ProviderRequestType::try_from(json_request.as_bytes()); + assert!(result.is_ok()); + + let request = result.unwrap(); + assert_eq!(request.model(), "gpt-4"); + assert_eq!(request.get_recent_user_message(), Some("Hello!".to_string())); + } + + #[test] + fn test_provider_streaming_response() { + // Test streaming response parsing with sample SSE data + let sse_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]} + +data: [DONE] +"#; + + let result = ProviderStreamResponseIter::try_from((sse_data.as_bytes(), &ProviderId::OpenAI)); + assert!(result.is_ok()); + + let mut streaming_response = result.unwrap(); + + // Test that we can iterate over chunks - it's just an iterator now! + let first_chunk = streaming_response.next(); + assert!(first_chunk.is_some()); + + let chunk_result = first_chunk.unwrap(); + assert!(chunk_result.is_ok()); + + let chunk = chunk_result.unwrap(); + assert_eq!(chunk.content_delta(), Some("Hello")); + assert!(!chunk.is_final()); + + // Test that stream ends properly + let final_chunk = streaming_response.next(); + assert!(final_chunk.is_none()); } } diff --git a/crates/hermesllm/src/providers/adapters.rs b/crates/hermesllm/src/providers/adapters.rs new file mode 100644 index 00000000..a001cf09 --- /dev/null +++ b/crates/hermesllm/src/providers/adapters.rs @@ -0,0 +1,39 @@ +use crate::providers::id::ProviderId; + +#[derive(Debug, Clone)] +pub enum AdapterType { + OpenAICompatible, + // Future: Claude, Gemini, etc. +} + +/// Provider adapter configuration +#[derive(Debug, Clone)] +pub struct ProviderConfig { + pub supported_apis: &'static [&'static str], + pub adapter_type: AdapterType, +} + +/// Check if provider has compatible API +pub fn has_compatible_api(provider_id: &ProviderId, api_path: &str) -> bool { + let config = get_provider_config(provider_id); + config.supported_apis.iter().any(|&supported| supported == api_path) +} + +/// Get supported APIs for provider +pub fn supported_apis(provider_id: &ProviderId) -> Vec<&'static str> { + let config = get_provider_config(provider_id); + config.supported_apis.to_vec() +} + +/// Get provider configuration +pub fn get_provider_config(provider_id: &ProviderId) -> ProviderConfig { + match provider_id { + ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek + | ProviderId::Arch | ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => { + ProviderConfig { + supported_apis: &["/v1/chat/completions"], + adapter_type: AdapterType::OpenAICompatible, + } + } + } +} diff --git a/crates/hermesllm/src/providers/id.rs b/crates/hermesllm/src/providers/id.rs new file mode 100644 index 00000000..2c0c494e --- /dev/null +++ b/crates/hermesllm/src/providers/id.rs @@ -0,0 +1,45 @@ +use std::fmt::Display; + +/// Provider identifier enum - simple enum for identifying providers +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum ProviderId { + OpenAI, + Mistral, + Deepseek, + Groq, + Gemini, + Claude, + GitHub, + Arch, +} + +impl From<&str> for ProviderId { + fn from(value: &str) -> Self { + match value.to_lowercase().as_str() { + "openai" => ProviderId::OpenAI, + "mistral" => ProviderId::Mistral, + "deepseek" => ProviderId::Deepseek, + "groq" => ProviderId::Groq, + "gemini" => ProviderId::Gemini, + "claude" => ProviderId::Claude, + "github" => ProviderId::GitHub, + "arch" => ProviderId::Arch, + _ => panic!("Unknown provider: {}", value), + } + } +} + +impl Display for ProviderId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ProviderId::OpenAI => write!(f, "OpenAI"), + ProviderId::Mistral => write!(f, "Mistral"), + ProviderId::Deepseek => write!(f, "Deepseek"), + ProviderId::Groq => write!(f, "Groq"), + ProviderId::Gemini => write!(f, "Gemini"), + ProviderId::Claude => write!(f, "Claude"), + ProviderId::GitHub => write!(f, "GitHub"), + ProviderId::Arch => write!(f, "Arch"), + } + } +} diff --git a/crates/hermesllm/src/providers/mod.rs b/crates/hermesllm/src/providers/mod.rs index d8c30873..4abccc0c 100644 --- a/crates/hermesllm/src/providers/mod.rs +++ b/crates/hermesllm/src/providers/mod.rs @@ -1 +1,14 @@ -pub mod openai; +//! Provider implementations for different LLM APIs +//! +//! This module contains provider-specific implementations that handle +//! request/response conversion for different LLM service APIs. +//! +pub mod id; +pub mod request; +pub mod response; +pub mod adapters; + +pub use id::ProviderId; +pub use request::{ProviderRequestType, ProviderRequest, ProviderRequestError} ; +pub use response::{ProviderResponseType, ProviderStreamResponseIter, ProviderResponse, ProviderStreamResponse, TokenUsage }; +pub use adapters::*; diff --git a/crates/hermesllm/src/providers/openai/builder.rs b/crates/hermesllm/src/providers/openai/builder.rs deleted file mode 100644 index fa1f325e..00000000 --- a/crates/hermesllm/src/providers/openai/builder.rs +++ /dev/null @@ -1,114 +0,0 @@ -use serde_json::Value; - -use crate::providers::openai::types::{ChatCompletionsRequest, Message, StreamOptions}; - -#[derive(Debug, Clone)] -pub struct OpenAIRequestBuilder { - model: String, - messages: Vec, - temperature: Option, - top_p: Option, - n: Option, - max_tokens: Option, - stream: Option, - stop: Option>, - presence_penalty: Option, - frequency_penalty: Option, - stream_options: Option, - tools: Option>, -} - -impl OpenAIRequestBuilder { - pub fn new(model: impl Into, messages: Vec) -> Self { - Self { - model: model.into(), - messages, - temperature: None, - top_p: None, - n: None, - max_tokens: None, - stream: None, - stop: None, - presence_penalty: None, - frequency_penalty: None, - stream_options: None, - tools: None, - } - } - - pub fn temperature(mut self, temperature: f32) -> Self { - self.temperature = Some(temperature); - self - } - - pub fn top_p(mut self, top_p: f32) -> Self { - self.top_p = Some(top_p); - self - } - - pub fn n(mut self, n: u32) -> Self { - self.n = Some(n); - self - } - - pub fn max_tokens(mut self, max_tokens: u32) -> Self { - self.max_tokens = Some(max_tokens); - self - } - - pub fn stream(mut self, stream: bool) -> Self { - self.stream = Some(stream); - self - } - - pub fn stop(mut self, stop: Vec) -> Self { - self.stop = Some(stop); - self - } - - pub fn presence_penalty(mut self, presence_penalty: f32) -> Self { - self.presence_penalty = Some(presence_penalty); - self - } - - pub fn frequency_penalty(mut self, frequency_penalty: f32) -> Self { - self.frequency_penalty = Some(frequency_penalty); - self - } - - pub fn stream_options(mut self, include_usage: bool) -> Self { - self.stream = Some(true); - self.stream_options = Some(StreamOptions { include_usage }); - self - } - - pub fn tools(mut self, tools: Vec) -> Self { - self.tools = Some(tools); - self - } - - pub fn build(self) -> Result { - let request = ChatCompletionsRequest { - model: self.model, - messages: self.messages, - temperature: self.temperature, - top_p: self.top_p, - n: self.n, - max_tokens: self.max_tokens, - stream: self.stream, - stop: self.stop, - presence_penalty: self.presence_penalty, - frequency_penalty: self.frequency_penalty, - stream_options: self.stream_options, - tools: self.tools, - metadata: None, - }; - Ok(request) - } -} - -impl ChatCompletionsRequest { - pub fn builder(model: impl Into, messages: Vec) -> OpenAIRequestBuilder { - OpenAIRequestBuilder::new(model, messages) - } -} diff --git a/crates/hermesllm/src/providers/openai/mod.rs b/crates/hermesllm/src/providers/openai/mod.rs deleted file mode 100644 index ab228e50..00000000 --- a/crates/hermesllm/src/providers/openai/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod builder; -pub mod types; diff --git a/crates/hermesllm/src/providers/openai/types.rs b/crates/hermesllm/src/providers/openai/types.rs deleted file mode 100644 index 7dea64df..00000000 --- a/crates/hermesllm/src/providers/openai/types.rs +++ /dev/null @@ -1,563 +0,0 @@ -use std::collections::HashMap; -use std::fmt::Display; - -use serde::{Deserialize, Serialize}; -use serde_json::Value; -use serde_with::skip_serializing_none; -use std::convert::TryFrom; -use std::str; -use thiserror::Error; - -use crate::Provider; - -#[derive(Debug, Error)] -pub enum OpenAIError { - #[error("json error: {0}")] - JsonParseError(#[from] serde_json::Error), - #[error("utf8 parsing error: {0}")] - Utf8Error(#[from] std::str::Utf8Error), - #[error("invalid streaming data err {source}, data: {data}")] - InvalidStreamingData { - source: serde_json::Error, - data: String, - }, - #[error("unsupported provider: {provider}")] - UnsupportedProvider { provider: String }, -} - -type Result = std::result::Result; - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub enum MultiPartContentType { - #[serde(rename = "text")] - Text, - #[serde(rename = "image_url")] - ImageUrl, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub struct ImageUrl { - pub url: String, -} - -#[skip_serializing_none] -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub struct MultiPartContent { - pub text: Option, - pub image_url: Option, - #[serde(rename = "type")] - pub content_type: MultiPartContentType, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[serde(untagged)] -pub enum ContentType { - Text(String), - MultiPart(Vec), -} - -impl Display for ContentType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - ContentType::Text(text) => write!(f, "{}", text), - ContentType::MultiPart(multi_part) => { - let text_parts: Vec = multi_part - .iter() - .filter_map(|part| { - if part.content_type == MultiPartContentType::Text { - part.text.clone() - } else if part.content_type == MultiPartContentType::ImageUrl { - // skip image URLs or their data in text representation - None - } else { - panic!("Unsupported content type: {:?}", part.content_type); - } - }) - .collect(); - let combined_text = text_parts.join("\n"); - write!(f, "{}", combined_text) - } - } - } -} - -#[skip_serializing_none] -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Message { - pub role: String, - pub content: Option, -} - -impl Message { - pub fn new(content: String) -> Self { - Self { - role: "user".to_string(), - content: Some(ContentType::Text(content)), - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct StreamOptions { - pub include_usage: bool, -} - -#[skip_serializing_none] -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -pub struct ChatCompletionsRequest { - pub model: String, - pub messages: Vec, - pub temperature: Option, - pub top_p: Option, - pub n: Option, - pub max_tokens: Option, - pub stream: Option, - pub stop: Option>, - pub presence_penalty: Option, - pub frequency_penalty: Option, - pub stream_options: Option, - pub tools: Option>, - pub metadata: Option>, -} - -impl TryFrom<&[u8]> for ChatCompletionsRequest { - type Error = OpenAIError; - fn try_from(bytes: &[u8]) -> Result { - serde_json::from_slice(bytes).map_err(OpenAIError::from) - } -} - -#[skip_serializing_none] -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ChatCompletionsResponse { - pub id: String, - pub object: String, - pub created: u64, - pub choices: Vec, - pub usage: Option, -} - -impl TryFrom<&[u8]> for ChatCompletionsResponse { - type Error = OpenAIError; - fn try_from(bytes: &[u8]) -> Result { - serde_json::from_slice(bytes).map_err(OpenAIError::from) - } -} - -impl<'a> TryFrom<(&'a [u8], &'a Provider)> for ChatCompletionsResponse { - type Error = OpenAIError; - - fn try_from(input: (&'a [u8], &'a Provider)) -> Result { - // Use input.provider as needed, if necessary - serde_json::from_slice(input.0).map_err(OpenAIError::from) - } -} - -impl ChatCompletionsRequest { - pub fn to_bytes(&self, provider: Provider) -> Result> { - match provider { - Provider::OpenAI - | Provider::Arch - | Provider::Deepseek - | Provider::Mistral - | Provider::Groq - | Provider::Gemini - | Provider::Claude => serde_json::to_vec(self).map_err(OpenAIError::from), - _ => Err(OpenAIError::UnsupportedProvider { - provider: provider.to_string(), - }), - } - } -} - -#[skip_serializing_none] -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct Choice { - pub index: u32, - pub message: Message, - pub finish_reason: Option, -} - -#[skip_serializing_none] -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct Usage { - pub prompt_tokens: usize, - pub completion_tokens: usize, - pub total_tokens: usize, -} - -#[skip_serializing_none] -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DeltaMessage { - pub role: Option, - pub content: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct StreamChoice { - pub index: u32, - pub delta: DeltaMessage, - pub finish_reason: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ChatCompletionStreamResponse { - pub id: String, - pub object: String, - pub created: u64, - pub model: String, - pub choices: Vec, - pub usage: Option, -} - -pub struct SseChatCompletionIter -where - I: Iterator, - I::Item: AsRef, -{ - lines: I, -} - -impl SseChatCompletionIter -where - I: Iterator, - I::Item: AsRef, -{ - pub fn new(lines: I) -> Self { - Self { lines } - } -} - -impl Iterator for SseChatCompletionIter -where - I: Iterator, - I::Item: AsRef, -{ - type Item = Result; - - fn next(&mut self) -> Option { - for line in &mut self.lines { - let line = line.as_ref(); - if let Some(data) = line.strip_prefix("data: ") { - let data = data.trim(); - if data == "[DONE]" { - return None; - } - - if data == r#"{"type": "ping"}"# { - continue; // Skip ping messages - that is usually from anthropic - } - - return Some( - serde_json::from_str::(data).map_err(|e| { - OpenAIError::InvalidStreamingData { - source: e, - data: data.to_string(), - } - }), - ); - } - } - None - } -} - -impl<'a> TryFrom<(&'a [u8], &'a Provider)> for SseChatCompletionIter> { - type Error = OpenAIError; - - fn try_from(input: (&'a [u8], &'a Provider)) -> Result { - let s = std::str::from_utf8(input.0)?; - // Use input.provider as needed - Ok(SseChatCompletionIter::new(s.lines())) - } -} - -impl<'a> TryFrom<&'a [u8]> for SseChatCompletionIter> { - type Error = OpenAIError; - - fn try_from(bytes: &'a [u8]) -> Result { - let s = std::str::from_utf8(bytes)?; - Ok(SseChatCompletionIter::new(s.lines())) - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ModelDetail { - pub id: String, - pub object: String, - pub created: usize, - pub owned_by: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum ModelObject { - #[serde(rename = "list")] - List, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Models { - pub object: ModelObject, - pub data: Vec, -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_content_type_display() { - let text_content = ContentType::Text("Hello, world!".to_string()); - assert_eq!(text_content.to_string(), "Hello, world!"); - - let multi_part_content = ContentType::MultiPart(vec![ - MultiPartContent { - text: Some("This is a text part.".to_string()), - content_type: MultiPartContentType::Text, - image_url: None, - }, - MultiPartContent { - text: Some("https://example.com/image.png".to_string()), - content_type: MultiPartContentType::ImageUrl, - image_url: None, - }, - ]); - assert_eq!(multi_part_content.to_string(), "This is a text part."); - } - - #[test] - fn test_chat_completions_request_text_type_array() { - const CHAT_COMPLETIONS_REQUEST: &str = r#" - { - "model": "gpt-3.5-turbo", - "messages": [ - { - "role": "user", - "content": [ - { - "type": "text", - "text": "What city do you want to know the weather for?" - }, - { - "type": "text", - "text": "hello world" - } - ] - } - ] - } - "#; - - let chat_completions_request: ChatCompletionsRequest = - serde_json::from_str(CHAT_COMPLETIONS_REQUEST).unwrap(); - assert_eq!(chat_completions_request.model, "gpt-3.5-turbo"); - if let Some(ContentType::MultiPart(multi_part_content)) = - chat_completions_request.messages[0].content.as_ref() - { - assert_eq!(multi_part_content.len(), 2); - assert_eq!( - multi_part_content[0].content_type, - MultiPartContentType::Text - ); - assert_eq!( - multi_part_content[0].text, - Some("What city do you want to know the weather for?".to_string()) - ); - assert_eq!( - multi_part_content[1].content_type, - MultiPartContentType::Text - ); - assert_eq!(multi_part_content[1].text, Some("hello world".to_string())); - } else { - panic!("Expected MultiPartContent"); - } - } - - #[test] - fn test_chat_completions_request_image_content() { - const CHAT_COMPLETIONS_REQUEST: &str = r#" - { - "stream": true, - "model": "openai/gpt-4o", - "messages": [ - { - "role": "user", - "content": [ - { - "type": "text", - "text": "describe this photo pls" - }, - { - "type": "image_url", - "image_url": { - "url": "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/...==" - } - } - ] - } - ] - }"#; - - let chat_completions_request: ChatCompletionsRequest = - serde_json::from_str(CHAT_COMPLETIONS_REQUEST).unwrap(); - assert_eq!(chat_completions_request.model, "openai/gpt-4o"); - if let Some(ContentType::MultiPart(multi_part_content)) = - chat_completions_request.messages[0].content.as_ref() - { - assert_eq!(multi_part_content.len(), 2); - assert_eq!( - multi_part_content[0].content_type, - MultiPartContentType::Text - ); - assert_eq!( - multi_part_content[0].text, - Some("describe this photo pls".to_string()) - ); - assert_eq!( - multi_part_content[1].content_type, - MultiPartContentType::ImageUrl - ); - assert_eq!( - multi_part_content[1].image_url, - Some(ImageUrl { - url: "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/...==".to_string(), - }) - ); - } else { - panic!("Expected MultiPartContent"); - } - } - - #[test] - fn test_sse_streaming() { - let json_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]} -data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"Hello, how can I help you today?"},"finish_reason":null}]} -data: [DONE]"#; - - let iter = SseChatCompletionIter::new(json_data.lines()); - - println!("Testing SSE Streaming"); - for item in iter { - match item { - Ok(response) => { - println!("Received response: {:?}", response); - if response.choices.is_empty() { - continue; - } - for choice in response.choices { - if let Some(content) = choice.delta.content { - println!("Content: {}", content); - } - } - } - Err(e) => { - println!("Error parsing JSON: {}", e); - return; - } - } - } - } - - #[test] - fn test_sse_streaming_try_from_bytes() { - let json_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]} -data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"Hello, how can I help you today?"},"finish_reason":null}]} -data: [DONE]"#; - - let iter = SseChatCompletionIter::try_from(json_data.as_bytes()) - .expect("Failed to create SSE iterator"); - - println!("Testing SSE Streaming"); - for item in iter { - match item { - Ok(response) => { - println!("Received response: {:?}", response); - if response.choices.is_empty() { - continue; - } - for choice in response.choices { - if let Some(content) = choice.delta.content { - println!("Content: {}", content); - } - } - } - Err(e) => { - println!("Error parsing JSON: {}", e); - return; - } - } - } - } - - #[test] - fn parse_chat_completions_request() { - const CHAT_COMPLETIONS_REQUEST: &str = r#" -{ - "model": "None", - "messages": [ - { - "role": "user", - "content": "how is the weather in seattle" - } - ], - "stream": true -} "#; - - let _chat_completions_request: ChatCompletionsRequest = - ChatCompletionsRequest::try_from(CHAT_COMPLETIONS_REQUEST.as_bytes()) - .expect("Failed to parse ChatCompletionsRequest"); - } - - #[test] - fn stream_chunk_parse_claude() { - const CHUNK_RESPONSE: &str = r#"data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"role":"assistant"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"} - -data: {"type": "ping"} - -data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"content":"Hello!"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"} - -data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"content":" How can I assist you today? Whether"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"} - -data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"content":" you have a question, need information"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"} - -data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"content":", or just want to chat about"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"} - -data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"content":" something, I'm here to help. What woul"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"} - -data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"content":"d you like to talk about?"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"} - -data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"} - -data: [DONE] -"#; - - let iter = SseChatCompletionIter::try_from(CHUNK_RESPONSE.as_bytes()); - - assert!(iter.is_ok(), "Failed to create SSE iterator"); - let iter: SseChatCompletionIter> = iter.unwrap(); - - let all_text: Vec = iter - .map(|item| { - let response = item.expect("Failed to parse response"); - response - .choices - .into_iter() - .filter_map(|choice| choice.delta.content) - .map(|content| content.to_string()) - .collect::() - }) - .collect(); - - assert_eq!( - all_text.len(), - 8, - "Expected 8 chunks of text, but got {}", - all_text.len() - ); - - assert_eq!( - all_text.join(""), - "Hello! How can I assist you today? Whether you have a question, need information, or just want to chat about something, I'm here to help. What would you like to talk about?" - ); - } -} diff --git a/crates/hermesllm/src/providers/request.rs b/crates/hermesllm/src/providers/request.rs new file mode 100644 index 00000000..1eb39416 --- /dev/null +++ b/crates/hermesllm/src/providers/request.rs @@ -0,0 +1,115 @@ + +use crate::apis::openai::ChatCompletionsRequest; +use super::{ProviderId, get_provider_config, AdapterType}; +use std::error::Error; +use std::fmt; +pub enum ProviderRequestType { + ChatCompletionsRequest(ChatCompletionsRequest), + //MessagesRequest(MessagesRequest), + //add more request types here +} + +impl TryFrom<&[u8]> for ProviderRequestType { + type Error = std::io::Error; + + // if passing bytes without provider id we assume the request is in OpenAI format + fn try_from(bytes: &[u8]) -> Result { + let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request)) + } +} + +impl TryFrom<(&[u8], &ProviderId)> for ProviderRequestType { + type Error = std::io::Error; + + fn try_from((bytes, provider_id): (&[u8], &ProviderId)) -> Result { + let config = get_provider_config(provider_id); + match config.adapter_type { + AdapterType::OpenAICompatible => { + let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request)) + } + // Future: handle other adapter types like Claude + } + } +} + +pub trait ProviderRequest: Send + Sync { + /// Extract the model name from the request + fn model(&self) -> &str; + + /// Set the model name for the request + fn set_model(&mut self, model: String); + + /// Check if this is a streaming request + fn is_streaming(&self) -> bool; + + /// Extract text content from messages for token counting + fn extract_messages_text(&self) -> String; + + /// Extract the user message for tracing/logging purposes + fn get_recent_user_message(&self) -> Option; + + /// Convert the request to bytes for transmission + fn to_bytes(&self) -> Result, ProviderRequestError>; +} + +impl ProviderRequest for ProviderRequestType { + fn model(&self) -> &str { + match self { + Self::ChatCompletionsRequest(r) => r.model(), + } + } + + fn set_model(&mut self, model: String) { + match self { + Self::ChatCompletionsRequest(r) => r.set_model(model), + } + } + + fn is_streaming(&self) -> bool { + match self { + Self::ChatCompletionsRequest(r) => r.is_streaming(), + } + } + + fn extract_messages_text(&self) -> String { + match self { + Self::ChatCompletionsRequest(r) => r.extract_messages_text(), + } + } + + fn get_recent_user_message(&self) -> Option { + match self { + Self::ChatCompletionsRequest(r) => r.get_recent_user_message(), + } + } + + fn to_bytes(&self) -> Result, ProviderRequestError> { + match self { + Self::ChatCompletionsRequest(r) => r.to_bytes(), + } + } +} + + +/// Error types for provider operations +#[derive(Debug)] +pub struct ProviderRequestError { + pub message: String, + pub source: Option>, +} + +impl fmt::Display for ProviderRequestError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Provider request error: {}", self.message) + } +} + +impl Error for ProviderRequestError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static)) + } +} diff --git a/crates/hermesllm/src/providers/response.rs b/crates/hermesllm/src/providers/response.rs new file mode 100644 index 00000000..faca303f --- /dev/null +++ b/crates/hermesllm/src/providers/response.rs @@ -0,0 +1,167 @@ +use std::error::Error; +use std::fmt; + +use crate::apis::openai::ChatCompletionsResponse; +use crate::apis::OpenAISseIter; +use crate::providers::id::ProviderId; +use crate::providers::adapters::{get_provider_config, AdapterType}; + +pub enum ProviderResponseType { + ChatCompletionsResponse(ChatCompletionsResponse), + //MessagesResponse(MessagesResponse), +} + +pub enum ProviderStreamResponseIter { + ChatCompletionsStream(OpenAISseIter>), + //MessagesStream(AnthropicSseIter>), +} + +impl TryFrom<(&[u8], ProviderId)> for ProviderResponseType { + type Error = std::io::Error; + + fn try_from((bytes, provider_id): (&[u8], ProviderId)) -> Result { + let config = get_provider_config(&provider_id); + match config.adapter_type { + AdapterType::OpenAICompatible => { + let chat_completions_response: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + Ok(ProviderResponseType::ChatCompletionsResponse(chat_completions_response)) + } + // Future: handle other adapter types like Claude + } + } +} + +impl TryFrom<(&[u8], &ProviderId)> for ProviderStreamResponseIter { + type Error = Box; + + fn try_from((bytes, provider_id): (&[u8], &ProviderId)) -> Result { + let config = get_provider_config(provider_id); + + // Parse SSE (Server-Sent Events) streaming data - protocol layer + let s = std::str::from_utf8(bytes)?; + let lines: Vec = s.lines().map(|line| line.to_string()).collect(); + + match config.adapter_type { + AdapterType::OpenAICompatible => { + // Delegate to OpenAI-specific iterator implementation + let sse_container = SseStreamIter::new(lines.into_iter()); + let iter = crate::apis::openai::OpenAISseIter::new(sse_container); + Ok(ProviderStreamResponseIter::ChatCompletionsStream(iter)) + } + // Future: AdapterType::Claude => { + // let sse_container = SseStreamIter::new(lines.into_iter()); + // let iter = crate::apis::anthropic::AnthropicSseIter::new(sse_container); + // Ok(ProviderStreamResponseIter::MessagesStream(iter)) + // } + } + } +} + + +impl Iterator for ProviderStreamResponseIter { + type Item = Result, Box>; + + fn next(&mut self) -> Option { + match self { + ProviderStreamResponseIter::ChatCompletionsStream(iter) => iter.next(), + // Future: ProviderStreamResponseIter::MessagesStream(iter) => iter.next(), + } + } +} + + +pub trait ProviderResponse: Send + Sync { + /// Get usage information if available - returns dynamic trait object + fn usage(&self) -> Option<&dyn TokenUsage>; + + /// Extract token counts for metrics + fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> { + self.usage().map(|u| (u.prompt_tokens(), u.completion_tokens(), u.total_tokens())) + } +} + +pub trait ProviderStreamResponse: Send + Sync { + /// Get the content delta for this chunk + fn content_delta(&self) -> Option<&str>; + + /// Check if this is the final chunk in the stream + fn is_final(&self) -> bool; + + /// Get role information if available + fn role(&self) -> Option<&str>; +} + + + +// ============================================================================ +// GENERIC SSE STREAMING ITERATOR (Container Only) +// ============================================================================ + +/// Generic SSE (Server-Sent Events) streaming iterator container +/// This is just a simple wrapper - actual Iterator implementation is delegated to provider-specific modules +pub struct SseStreamIter +where + I: Iterator, + I::Item: AsRef, +{ + pub lines: I, +} + +impl SseStreamIter +where + I: Iterator, + I::Item: AsRef, +{ + pub fn new(lines: I) -> Self { + Self { lines } + } +} + + +impl ProviderResponse for ProviderResponseType { + fn usage(&self) -> Option<&dyn TokenUsage> { + match self { + ProviderResponseType::ChatCompletionsResponse(resp) => resp.usage(), + // Future: ProviderResponseType::MessagesResponse(resp) => resp.usage(), + } + } + + fn extract_usage_counts(&self) -> Option<(usize, usize, usize)> { + match self { + ProviderResponseType::ChatCompletionsResponse(resp) => resp.extract_usage_counts(), + // Future: ProviderResponseType::MessagesResponse(resp) => resp.extract_usage_counts(), + } + } +} + +// Implement Send + Sync for the enum to match the original trait requirements +unsafe impl Send for ProviderStreamResponseIter {} +unsafe impl Sync for ProviderStreamResponseIter {} + +/// Trait for token usage information +pub trait TokenUsage { + fn completion_tokens(&self) -> usize; + fn prompt_tokens(&self) -> usize; + fn total_tokens(&self) -> usize; +} + + +#[derive(Debug)] +pub struct ProviderResponseError { + pub message: String, + pub source: Option>, +} + + +impl fmt::Display for ProviderResponseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Provider response error: {}", self.message) + } +} + +impl Error for ProviderResponseError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static)) + } +} diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 82b88509..6b2c5f15 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -10,11 +10,10 @@ use common::ratelimit::Header; use common::stats::{IncrementingMetric, RecordingMetric}; use common::tracing::{Event, Span, TraceData, Traceparent}; use common::{ratelimit, routing, tokenizer}; -use hermesllm::providers::openai::types::{ChatCompletionsRequest, SseChatCompletionIter}; -use hermesllm::providers::openai::types::{ - ChatCompletionsResponse, ContentType, Message, StreamOptions, +use hermesllm::providers::response::ProviderStreamResponseIter; +use hermesllm::{ + ProviderId, ProviderRequest, ProviderRequestType, ProviderResponse, ProviderResponseType, }; -use hermesllm::Provider; use http::StatusCode; use log::{debug, info, warn}; use proxy_wasm::hostcalls::get_current_time; @@ -41,9 +40,9 @@ pub struct StreamContext { ttft_time: Option, traceparent: Option, request_body_sent_time: Option, - user_message: Option, traces_queue: Arc>>, overrides: Rc>, + user_message: Option, } impl StreamContext { @@ -69,9 +68,9 @@ impl StreamContext { ttft_duration: None, traceparent: None, ttft_time: None, - user_message: None, traces_queue, request_body_sent_time: None, + user_message: None, } } fn llm_provider(&self) -> &LlmProvider { @@ -80,6 +79,10 @@ impl StreamContext { .expect("the provider should be set when asked for it") } + fn get_provider_id(&self) -> ProviderId { + self.llm_provider().to_provider_id() + } + fn select_llm_provider(&mut self) { let provider_hint = self .get_http_request_header(ARCH_PROVIDER_HINT_HEADER) @@ -295,24 +298,23 @@ impl HttpContext for StreamContext { } }; - let mut deserialized_body = match ChatCompletionsRequest::try_from(body_bytes.as_slice()) { - Ok(deserialized) => deserialized, - Err(e) => { - debug!( - "on_http_request_body: request body: {}", - String::from_utf8_lossy(&body_bytes) - ); - self.send_server_error(ServerError::OpenAIPError(e), Some(StatusCode::BAD_REQUEST)); - return Action::Pause; - } - }; + let provider_id = self.get_provider_id(); - self.user_message = deserialized_body - .messages - .iter() - .filter(|m| m.role == "user") - .last() - .cloned(); + let mut deserialized_body = + match ProviderRequestType::try_from((&body_bytes[..], &provider_id)) { + Ok(deserialized) => deserialized, + Err(e) => { + debug!( + "on_http_request_body: request body: {}", + String::from_utf8_lossy(&body_bytes) + ); + self.send_server_error( + ServerError::LogicError(format!("Request parsing error: {}", e)), + Some(StatusCode::BAD_REQUEST), + ); + return Action::Pause; + } + }; let model_name = match self.llm_provider.as_ref() { Some(llm_provider) => llm_provider.model.as_ref(), @@ -324,24 +326,38 @@ impl HttpContext for StreamContext { None => false, }; - let model_requested = deserialized_body.model.clone(); - deserialized_body.model = match model_name { + // Store the original model for logging + let model_requested = deserialized_body.model().to_string(); + + // Apply model name resolution logic using the trait method + let resolved_model = match model_name { Some(model_name) => model_name.clone(), None => { if use_agent_orchestrator { "agent_orchestrator".to_string() } else { self.send_server_error( - ServerError::BadRequest { - why: format!("No model specified in request and couldn't determine model name from arch_config. Model name in req: {}, arch_config, provider: {}, model: {:?}", deserialized_body.model, self.llm_provider().name, self.llm_provider().model).to_string(), - }, - Some(StatusCode::BAD_REQUEST), - ); + ServerError::BadRequest { + why: format!( + "No model specified in request and couldn't determine model name from arch_config. Model name in req: {}, arch_config, provider: {}, model: {:?}", + model_requested, + self.llm_provider().name, + self.llm_provider().model + ), + }, + Some(StatusCode::BAD_REQUEST), + ); return Action::Continue; } } }; + // Set the resolved model using the trait method + deserialized_body.set_model(resolved_model.clone()); + + // Extract user message for tracing + self.user_message = deserialized_body.get_recent_user_message(); + info!( "on_http_request_body: provider: {}, model requested (in body): {}, model selected: {}", self.llm_provider().name, @@ -349,32 +365,13 @@ impl HttpContext for StreamContext { model_name.unwrap_or(&"None".to_string()), ); - if deserialized_body.stream.unwrap_or_default() { - self.streaming_response = true; - } - if deserialized_body.stream.unwrap_or_default() - && deserialized_body.stream_options.is_none() - { - deserialized_body.stream_options = Some(StreamOptions { - include_usage: true, - }); - } + // Use provider interface for streaming detection and setup + self.streaming_response = deserialized_body.is_streaming(); - // only use the tokens from the messages, excluding the metadata and json tags - let input_tokens_str = deserialized_body - .messages - .iter() - .fold(String::new(), |acc, m| { - acc + " " - + m.content - .as_ref() - .unwrap_or(&ContentType::Text(String::new())) - .to_string() - .as_str() - }); + // Use provider interface for text extraction (after potential mutation) + let input_tokens_str = deserialized_body.extract_messages_text(); // enforce ratelimits on ingress - if let Err(e) = self.enforce_ratelimits(&deserialized_body.model, input_tokens_str.as_str()) - { + if let Err(e) = self.enforce_ratelimits(&resolved_model, input_tokens_str.as_str()) { self.send_server_error( ServerError::ExceededRatelimit(e), Some(StatusCode::TOO_MANY_REQUESTS), @@ -383,15 +380,15 @@ impl HttpContext for StreamContext { return Action::Continue; } - let llm_provider_str = self.llm_provider().provider_interface.to_string(); - let hermes_llm_provider = Provider::from(llm_provider_str.as_str()); - - // convert chat completion request to llm provider specific request - let deserialized_body_bytes = match deserialized_body.to_bytes(hermes_llm_provider) { + // Convert chat completion request to llm provider specific request using provider interface + let deserialized_body_bytes = match deserialized_body.to_bytes() { Ok(bytes) => bytes, Err(e) => { warn!("Failed to serialize request body: {}", e); - self.send_server_error(ServerError::OpenAIPError(e), Some(StatusCode::BAD_REQUEST)); + self.send_server_error( + ServerError::LogicError(format!("Request serialization error: {}", e)), + Some(StatusCode::BAD_REQUEST), + ); return Action::Pause; } }; @@ -484,17 +481,16 @@ impl HttpContext for StreamContext { self.request_body_sent_time.unwrap(), current_time_ns, ); - if let Some(user_message) = self.user_message.as_ref() { - if let Some(prompt) = user_message.content.as_ref() { - llm_span - .add_attribute("user_prompt".to_string(), prompt.to_string()); - } - } llm_span.add_attribute( "model".to_string(), self.llm_provider().name.to_string(), ); + if let Some(user_message) = &self.user_message { + llm_span + .add_attribute("user_message".to_string(), user_message.clone()); + } + if self.ttft_time.is_some() { llm_span.add_event(Event::new( "time_to_first_token".to_string(), @@ -558,62 +554,69 @@ impl HttpContext for StreamContext { ); } - let llm_provider_str = self.llm_provider().provider_interface.to_string(); - let hermes_llm_provider = Provider::from(llm_provider_str.as_str()); - if self.streaming_response { - let chat_completions_chunk_response_events = - match SseChatCompletionIter::try_from((body.as_slice(), &hermes_llm_provider)) { - Ok(events) => events, - Err(e) => { - warn!( - "could not parse response: {}, body str: {}", - e, - String::from_utf8_lossy(&body) - ); - return Action::Continue; - } - }; + debug!("processing streaming response"); + match ProviderStreamResponseIter::try_from((&body[..], &self.get_provider_id())) { + Ok(mut streaming_response) => { + // Process each streaming chunk + while let Some(chunk_result) = streaming_response.next() { + match chunk_result { + Ok(chunk) => { + // Compute TTFT on first chunk + if self.ttft_duration.is_none() { + let current_time = get_current_time().unwrap(); + self.ttft_time = Some(current_time_ns()); + match current_time.duration_since(self.start_time) { + Ok(duration) => { + let duration_ms = duration.as_millis(); + info!( + "on_http_response_body: time to first token: {}ms", + duration_ms + ); + self.ttft_duration = Some(duration); + self.metrics + .time_to_first_token + .record(duration_ms as u64); + } + Err(e) => { + warn!("SystemTime error: {:?}", e); + } + } + } - for event in chat_completions_chunk_response_events { - match event { - Ok(event) => { - if let Some(usage) = event.usage.as_ref() { - self.response_tokens += usage.completion_tokens; + // For streaming responses, we handle token counting differently + // The ProviderStreamResponse trait provides content_delta, is_final, and role + // Token counting for streaming responses typically happens with final usage chunk + if chunk.is_final() { + // For now, we'll implement basic token estimation + // In a complete implementation, the final chunk would contain usage information + debug!("Received final streaming chunk"); + } + + // For now, estimate tokens from content delta + if let Some(content) = chunk.content_delta() { + // Rough estimation: ~4 characters per token + let estimated_tokens = content.len() / 4; + self.response_tokens += estimated_tokens.max(1); + } + } + Err(e) => { + warn!("Error processing streaming chunk: {}", e); + return Action::Continue; + } } } - Err(e) => { - warn!("error in response event: {}", e); - continue; - } } - } - - // Compute TTFT if not already recorded - if self.ttft_duration.is_none() { - // if let Some(start_time) = self.start_time { - let current_time = get_current_time().unwrap(); - self.ttft_time = Some(current_time_ns()); - match current_time.duration_since(self.start_time) { - Ok(duration) => { - let duration_ms = duration.as_millis(); - info!( - "on_http_response_body: time to first token: {}ms", - duration_ms - ); - self.ttft_duration = Some(duration); - self.metrics.time_to_first_token.record(duration_ms as u64); - } - Err(e) => { - warn!("SystemTime error: {:?}", e); - } + Err(e) => { + warn!("Failed to parse streaming response: {}", e); } } } else { debug!("non streaming response"); - let chat_completions_response = - match ChatCompletionsResponse::try_from((body.as_slice(), &hermes_llm_provider)) { - Ok(de) => de, + let provider_id = self.get_provider_id(); + let response: ProviderResponseType = + match ProviderResponseType::try_from((&body[..], provider_id)) { + Ok(response) => response, Err(e) => { warn!( "could not parse response: {}, body str: {}", @@ -626,15 +629,24 @@ impl HttpContext for StreamContext { String::from_utf8_lossy(&body) ); self.send_server_error( - ServerError::OpenAIPError(e), + ServerError::LogicError(format!("Response parsing error: {}", e)), Some(StatusCode::BAD_REQUEST), ); return Action::Continue; } }; - if let Some(usage) = chat_completions_response.usage { - self.response_tokens += usage.completion_tokens; + // Use provider interface to extract usage information + if let Some((prompt_tokens, completion_tokens, total_tokens)) = + response.extract_usage_counts() + { + debug!( + "Response usage: prompt={}, completion={}, total={}", + prompt_tokens, completion_tokens, total_tokens + ); + self.response_tokens = completion_tokens; + } else { + warn!("No usage information found in response"); } } diff --git a/crates/llm_gateway/tests/integration.rs b/crates/llm_gateway/tests/integration.rs index 108ab1ce..82ae8322 100644 --- a/crates/llm_gateway/tests/integration.rs +++ b/crates/llm_gateway/tests/integration.rs @@ -12,7 +12,7 @@ fn wasm_module() -> String { wasm_file.exists(), "Run `cargo build --release --target=wasm32-wasip1` first" ); - wasm_file.to_str().unwrap().to_string() + wasm_file.to_string_lossy().to_string() } fn request_headers_expectations(module: &mut Tester, http_context: i32) { @@ -267,17 +267,12 @@ fn llm_gateway_bad_request_to_open_ai_chat_completions() { .returning(Some(incomplete_chat_completions_request_body)) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Info), Some("on_http_request_body: provider: open-ai-gpt-4, model requested (in body): gpt-1, model selected: gpt-4")) - .expect_send_local_response( - Some(StatusCode::BAD_REQUEST.as_u16().into()), - None, - None, - None, - ) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), Some("getting token count model=gpt-4")) + .expect_log(Some(LogLevel::Debug), Some("Recorded input token count: 13")) .expect_metric_record("input_sequence_length", 13) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), Some("Applying ratelimit for model: gpt-4")) + .expect_log(Some(LogLevel::Debug), Some(r#"Checking limit for provider=gpt-4, with selector=Header { key: "selector-key", value: "selector-value" }, consuming tokens=13"#)) + .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None) .execute_and_expect(ReturnType::Action(Action::Continue)) .unwrap(); } @@ -386,11 +381,11 @@ fn llm_gateway_request_not_ratelimited() { .returning(Some(chat_completions_request_body)) // The actual call is not important in this test, we just need to grab the token_id .expect_log(Some(LogLevel::Info), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), Some("getting token count model=gpt-4")) + .expect_log(Some(LogLevel::Debug), Some("Recorded input token count: 29")) .expect_metric_record("input_sequence_length", 29) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), Some("Applying ratelimit for model: gpt-4")) + .expect_log(Some(LogLevel::Debug), Some(r#"Checking limit for provider=gpt-4, with selector=Header { key: "selector-key", value: "selector-value" }, consuming tokens=29"#)) .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None) .execute_and_expect(ReturnType::Action(Action::Continue)) .unwrap(); @@ -433,11 +428,11 @@ fn llm_gateway_override_model_name() { // The actual call is not important in this test, we just need to grab the token_id .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Info), Some("on_http_request_body: provider: open-ai-gpt-4, model requested (in body): gpt-1, model selected: gpt-4")) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), Some("getting token count model=gpt-4")) + .expect_log(Some(LogLevel::Debug), Some("Recorded input token count: 29")) .expect_metric_record("input_sequence_length", 29) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), Some("Applying ratelimit for model: gpt-4")) + .expect_log(Some(LogLevel::Debug), Some(r#"Checking limit for provider=gpt-4, with selector=Header { key: "selector-key", value: "selector-value" }, consuming tokens=29"#)) .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None) .execute_and_expect(ReturnType::Action(Action::Continue)) .unwrap(); @@ -483,8 +478,8 @@ fn llm_gateway_override_use_default_model() { Some(LogLevel::Info), Some("on_http_request_body: provider: open-ai-gpt-4, model requested (in body): gpt-1, model selected: gpt-4"), ) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), Some("getting token count model=gpt-4")) + .expect_log(Some(LogLevel::Debug), Some("Recorded input token count: 29")) .expect_metric_record("input_sequence_length", 29) .expect_log(Some(LogLevel::Debug), Some("Applying ratelimit for model: gpt-4")) .expect_log(Some(LogLevel::Debug), Some(r#"Checking limit for provider=gpt-4, with selector=Header { key: "selector-key", value: "selector-value" }, consuming tokens=29"#)) @@ -530,11 +525,11 @@ fn llm_gateway_override_use_model_name_none() { // The actual call is not important in this test, we just need to grab the token_id .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Info), Some("on_http_request_body: provider: open-ai-gpt-4, model requested (in body): none, model selected: gpt-4")) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), Some("getting token count model=gpt-4")) + .expect_log(Some(LogLevel::Debug), Some("Recorded input token count: 29")) .expect_metric_record("input_sequence_length", 29) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), Some("Applying ratelimit for model: gpt-4")) + .expect_log(Some(LogLevel::Debug), Some(r#"Checking limit for provider=gpt-4, with selector=Header { key: "selector-key", value: "selector-value" }, consuming tokens=29"#)) .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None) .execute_and_expect(ReturnType::Action(Action::Continue)) .unwrap();