From e73a9eb61cb6f99fd8b381ec72ab6b1828954f28 Mon Sep 17 00:00:00 2001 From: Salman Paracha Date: Fri, 22 Aug 2025 14:36:46 -0700 Subject: [PATCH] transformations are working. Now need to add some tests next --- crates/hermesllm/src/clients/endpoints.rs | 33 ---- crates/hermesllm/src/lib.rs | 28 +-- crates/hermesllm/src/providers/id.rs | 18 ++ crates/hermesllm/src/providers/request.rs | 43 +---- crates/hermesllm/src/providers/response.rs | 59 +++--- crates/llm_gateway/src/stream_context.rs | 208 +++++++++++---------- 6 files changed, 182 insertions(+), 207 deletions(-) diff --git a/crates/hermesllm/src/clients/endpoints.rs b/crates/hermesllm/src/clients/endpoints.rs index a834969b..4f73221f 100644 --- a/crates/hermesllm/src/clients/endpoints.rs +++ b/crates/hermesllm/src/clients/endpoints.rs @@ -31,14 +31,6 @@ pub enum SupportedApi { } impl SupportedApi { - /// Determine if a request/response conversion is required for the given model string - pub fn requires_conversion_for_model(&self, model: &str) -> bool { - use crate::providers::adapters::is_claude_family; - match self { - SupportedApi::Anthropic(AnthropicApi::Messages) => !is_claude_family(model), - SupportedApi::OpenAI(OpenAIApi::ChatCompletions) => is_claude_family(model), - } - } /// Create a SupportedApi from an endpoint path pub fn from_endpoint(endpoint: &str) -> Option { if let Some(openai_api) = OpenAIApi::from_endpoint(endpoint) { @@ -60,14 +52,6 @@ impl SupportedApi { } } - /// Get the API family name - pub fn api_family(&self) -> &'static str { - match self { - SupportedApi::OpenAI(_) => "openai", - SupportedApi::Anthropic(_) => "anthropic", - } - } - /// Determine the target endpoint for a given provider /// For /v1/messages: if provider is Anthropic, use /v1/messages; otherwise use /v1/chat/completions pub fn target_endpoint_for_provider(&self, provider: &str) -> &'static str { @@ -83,23 +67,6 @@ impl SupportedApi { _ => self.endpoint() } } - - /// Check if request conversion is required for the given provider - /// True if we need to convert between Anthropic and OpenAI formats - pub fn requires_conversion(&self, provider: &str) -> bool { - match self { - SupportedApi::Anthropic(AnthropicApi::Messages) => { - // If provider is not Anthropic/Claude, we need to convert to OpenAI format - !(provider.to_lowercase().contains("anthropic") || - provider.to_lowercase().contains("claude")) - } - SupportedApi::OpenAI(OpenAIApi::ChatCompletions) => { - // If provider is Anthropic/Claude but request is OpenAI format, need conversion - provider.to_lowercase().contains("anthropic") || - provider.to_lowercase().contains("claude") - } - } - } } diff --git a/crates/hermesllm/src/lib.rs b/crates/hermesllm/src/lib.rs index b4ad9932..89955b73 100644 --- a/crates/hermesllm/src/lib.rs +++ b/crates/hermesllm/src/lib.rs @@ -66,28 +66,30 @@ mod tests { #[test] fn test_provider_streaming_response() { // Test streaming response parsing with sample SSE data - let sse_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]} + let sse_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]} data: [DONE] "#; - let result = ProviderStreamResponseIter::try_from((sse_data.as_bytes(), &ProviderId::OpenAI)); - assert!(result.is_ok()); + use crate::clients::endpoints::SupportedApi; + let api = SupportedApi::OpenAI(crate::apis::OpenAIApi::ChatCompletions); + let result = ProviderStreamResponseIter::try_from((sse_data.as_bytes(), &api, &ProviderId::OpenAI)); + assert!(result.is_ok()); - let mut streaming_response = result.unwrap(); + let mut streaming_response = result.unwrap(); - // Test that we can iterate over chunks - it's just an iterator now! - let first_chunk = streaming_response.next(); - assert!(first_chunk.is_some()); + // Test that we can iterate over chunks - it's just an iterator now! + let first_chunk = streaming_response.next(); + assert!(first_chunk.is_some()); - let chunk_result = first_chunk.unwrap(); - assert!(chunk_result.is_ok()); + let chunk_result = first_chunk.unwrap(); + assert!(chunk_result.is_ok()); - let chunk = chunk_result.unwrap(); - assert_eq!(chunk.content_delta(), Some("Hello")); - assert!(!chunk.is_final()); + let chunk = chunk_result.unwrap(); + assert_eq!(chunk.content_delta(), Some("Hello")); + assert!(!chunk.is_final()); - // Test that stream ends properly + // Test that stream ends properly let final_chunk = streaming_response.next(); assert!(final_chunk.is_none()); } diff --git a/crates/hermesllm/src/providers/id.rs b/crates/hermesllm/src/providers/id.rs index 2c0c494e..f83a9104 100644 --- a/crates/hermesllm/src/providers/id.rs +++ b/crates/hermesllm/src/providers/id.rs @@ -1,4 +1,6 @@ use std::fmt::Display; +use crate::clients::endpoints::SupportedApi; +use crate::apis::{OpenAIApi, AnthropicApi}; /// Provider identifier enum - simple enum for identifying providers #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -29,6 +31,22 @@ impl From<&str> for ProviderId { } } +impl ProviderId { + /// Given a client API, return the compatible upstream API for this provider + pub fn compatible_api_for_client(&self, client_api: &SupportedApi) -> SupportedApi { + match (self, client_api) { + // Claude/Anthropic providers natively support Anthropic APIs + (ProviderId::Claude, SupportedApi::Anthropic(_)) => client_api.clone(), + // Claude/Anthropic providers can also support OpenAI chat completions by mapping to Anthropic Messages + (ProviderId::Claude, SupportedApi::OpenAI(OpenAIApi::ChatCompletions)) => SupportedApi::Anthropic(AnthropicApi::Messages), + + // OpenAI-compatible providers only support OpenAI chat completions + (ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch | ProviderId::Gemini | ProviderId::GitHub, SupportedApi::Anthropic(_)) => SupportedApi::OpenAI(OpenAIApi::ChatCompletions), + (ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch | ProviderId::Gemini | ProviderId::GitHub, SupportedApi::OpenAI(_)) => SupportedApi::OpenAI(OpenAIApi::ChatCompletions), + } + } +} + impl Display for ProviderId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { diff --git a/crates/hermesllm/src/providers/request.rs b/crates/hermesllm/src/providers/request.rs index e0d82566..b83a8d23 100644 --- a/crates/hermesllm/src/providers/request.rs +++ b/crates/hermesllm/src/providers/request.rs @@ -2,7 +2,6 @@ use crate::apis::openai::ChatCompletionsRequest; use crate::apis::anthropic::MessagesRequest; use crate::clients::endpoints::SupportedApi; -use super::{ProviderId, get_provider_config, AdapterType}; use std::error::Error; use std::fmt; pub enum ProviderRequestType { @@ -22,53 +21,23 @@ impl TryFrom<&[u8]> for ProviderRequestType { } } -impl TryFrom<(&[u8], &ProviderId)> for ProviderRequestType { +/// Parse request based on endpoint and provider information +impl TryFrom<(&[u8], &SupportedApi)> for ProviderRequestType { type Error = std::io::Error; - fn try_from((bytes, provider_id): (&[u8], &ProviderId)) -> Result { - let config = get_provider_config(provider_id); - match config.adapter_type { - AdapterType::OpenAICompatible => { + fn try_from((bytes, endpoint): (&[u8], &SupportedApi)) -> Result { + // Use SupportedApi to determine the appropriate request type + match endpoint { + SupportedApi::OpenAI(_) => { let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request)) - } - AdapterType::AnthropicCompatible => { - // For Anthropic providers, try to parse as MessagesRequest first, fallback to ChatCompletionsRequest - if let Ok(messages_request) = MessagesRequest::try_from(bytes) { - Ok(ProviderRequestType::MessagesRequest(messages_request)) - } else { - let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; - Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request)) - } - } - } - } -} - -/// Parse request based on endpoint and provider information -impl TryFrom<(&[u8], &str, &ProviderId)> for ProviderRequestType { - type Error = std::io::Error; - - fn try_from((bytes, endpoint, provider_id): (&[u8], &str, &ProviderId)) -> Result { - // Use SupportedApi to determine the appropriate request type - if let Some(api) = SupportedApi::from_endpoint(endpoint) { - match api { - SupportedApi::OpenAI(_) => { - let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; - Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request)) } SupportedApi::Anthropic(_) => { let messages_request: MessagesRequest = MessagesRequest::try_from(bytes) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; Ok(ProviderRequestType::MessagesRequest(messages_request)) } - } - } else { - // Fallback to provider-based parsing for unsupported endpoints - Self::try_from((bytes, provider_id)) } } } diff --git a/crates/hermesllm/src/providers/response.rs b/crates/hermesllm/src/providers/response.rs index 56410540..2811ab66 100644 --- a/crates/hermesllm/src/providers/response.rs +++ b/crates/hermesllm/src/providers/response.rs @@ -1,11 +1,15 @@ +use crate::providers::id::ProviderId; + +use serde::Serialize; use std::error::Error; use std::fmt; use crate::apis::openai::ChatCompletionsResponse; use crate::apis::OpenAISseIter; -use crate::providers::id::ProviderId; -use crate::providers::adapters::{get_provider_config, AdapterType}; +use crate::clients::endpoints::SupportedApi; +use std::convert::TryFrom; +#[derive(Serialize)] pub enum ProviderResponseType { ChatCompletionsResponse(ChatCompletionsResponse), //MessagesResponse(MessagesResponse), @@ -16,51 +20,50 @@ pub enum ProviderStreamResponseIter { //MessagesStream(AnthropicSseIter>), } -impl TryFrom<(&[u8], ProviderId)> for ProviderResponseType { + +// --- Response transformation logic for client API compatibility --- +impl TryFrom<(&[u8], &SupportedApi, &ProviderId)> for ProviderResponseType { type Error = std::io::Error; - fn try_from((bytes, provider_id): (&[u8], ProviderId)) -> Result { - let config = get_provider_config(&provider_id); - match config.adapter_type { - AdapterType::OpenAICompatible => { - let chat_completions_response: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes) + fn try_from((bytes, client_api, provider_id): (&[u8], &SupportedApi, &ProviderId)) -> Result { + let upstream_api = provider_id.compatible_api_for_client(client_api); + match (&upstream_api, client_api) { + (SupportedApi::OpenAI(_), SupportedApi::OpenAI(_)) => { + let resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; - Ok(ProviderResponseType::ChatCompletionsResponse(chat_completions_response)) + Ok(ProviderResponseType::ChatCompletionsResponse(resp)) } - AdapterType::AnthropicCompatible => { - // TODO: Implement MessagesResponse parsing for Anthropic-compatible providers - todo!("AnthropicCompatible response parsing not yet implemented"); + (SupportedApi::OpenAI(_), SupportedApi::Anthropic(_)) => { + // If you add a MessagesResponse variant, return it here. For now, just error or serialize as needed. + Err(std::io::Error::new(std::io::ErrorKind::Other, "Anthropic response variant not implemented")) } + _ => Err(std::io::Error::new(std::io::ErrorKind::Other, "Unsupported response transformation")), } } } -impl TryFrom<(&[u8], &ProviderId)> for ProviderStreamResponseIter { +impl TryFrom<(&[u8], &SupportedApi, &ProviderId)> for ProviderStreamResponseIter { type Error = Box; - fn try_from((bytes, provider_id): (&[u8], &ProviderId)) -> Result { - let config = get_provider_config(provider_id); - - // Parse SSE (Server-Sent Events) streaming data - protocol layer - let s = std::str::from_utf8(bytes)?; - let lines: Vec = s.lines().map(|line| line.to_string()).collect(); - - match config.adapter_type { - AdapterType::OpenAICompatible => { - // Delegate to OpenAI-specific iterator implementation - let sse_container = SseStreamIter::new(lines.into_iter()); + fn try_from((bytes, client_api, provider_id): (&[u8], &SupportedApi, &ProviderId)) -> Result { + let upstream_api = provider_id.compatible_api_for_client(client_api); + match (&upstream_api, client_api) { + (SupportedApi::OpenAI(_), SupportedApi::OpenAI(_)) => { + let s = std::str::from_utf8(bytes)?; + let lines: Vec = s.lines().map(|line| line.to_string()).collect(); + let sse_container = crate::providers::response::SseStreamIter::new(lines.into_iter()); let iter = crate::apis::openai::OpenAISseIter::new(sse_container); Ok(ProviderStreamResponseIter::ChatCompletionsStream(iter)) } - AdapterType::AnthropicCompatible => { - // TODO: Implement Anthropic streaming support - todo!("AnthropicCompatible streaming not yet implemented"); + (SupportedApi::OpenAI(_), SupportedApi::Anthropic(_)) => { + // TODO: Implement streaming transformation from OpenAI to Anthropic + Err("Anthropic streaming response variant not implemented".into()) } + _ => Err("Unsupported streaming response transformation".into()), } } } - impl Iterator for ProviderStreamResponseIter { type Item = Result, Box>; diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index cc769c7a..ba301eda 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -12,9 +12,7 @@ use common::tracing::{Event, Span, TraceData, Traceparent}; use common::{ratelimit, routing, tokenizer}; use hermesllm::clients::endpoints::SupportedApi; use hermesllm::providers::response::ProviderStreamResponseIter; -use hermesllm::{ - ProviderId, ProviderRequest, ProviderRequestType, ProviderResponse, ProviderResponseType, -}; +use hermesllm::{ProviderId, ProviderRequest, ProviderRequestType, ProviderResponseType}; use http::StatusCode; use log::{debug, info, warn}; use proxy_wasm::hostcalls::get_current_time; @@ -33,6 +31,8 @@ pub struct StreamContext { streaming_response: bool, response_tokens: usize, supported_api: Option, + /// The API that should be used for the upstream provider (after compatibility mapping) + resolved_api: Option, llm_providers: Rc, llm_provider: Option>, request_id: Option, @@ -62,6 +62,7 @@ impl StreamContext { streaming_response: false, response_tokens: 0, supported_api: None, + resolved_api: None, llm_providers, llm_provider: None, request_id: None, @@ -223,6 +224,16 @@ impl HttpContext for StreamContext { let supported_api = SupportedApi::from_endpoint(&request_path); self.supported_api = supported_api; + // Determine the resolved (upstream) API using provider compatibility + if let (Some(api), Some(provider)) = + (self.supported_api.as_ref(), self.llm_provider.as_ref()) + { + let provider_id = provider.to_provider_id(); + self.resolved_api = Some(provider_id.compatible_api_for_client(api)); + } else { + self.resolved_api = None; + } + let use_agent_orchestrator = match self.overrides.as_ref() { Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(), None => false, @@ -340,22 +351,26 @@ impl HttpContext for StreamContext { } }; - let provider_id = self.get_provider_id(); - let request_path = self.get_http_request_header(":path").unwrap_or_default(); - - let mut deserialized_body = match ProviderRequestType::try_from(( - &body_bytes[..], - request_path.as_str(), - &provider_id, - )) { - Ok(deserialized) => deserialized, - Err(e) => { - debug!( - "on_http_request_body: request body: {}", - String::from_utf8_lossy(&body_bytes) - ); + let mut deserialized_body = match self.resolved_api.as_ref() { + Some(resolved_api) => { + match ProviderRequestType::try_from((&body_bytes[..], resolved_api)) { + Ok(deserialized) => deserialized, + Err(e) => { + debug!( + "on_http_request_body: request body: {}", + String::from_utf8_lossy(&body_bytes) + ); + self.send_server_error( + ServerError::LogicError(format!("Request parsing error: {}", e)), + Some(StatusCode::BAD_REQUEST), + ); + return Action::Pause; + } + } + } + None => { self.send_server_error( - ServerError::LogicError(format!("Request parsing error: {}", e)), + ServerError::LogicError("No resolved API for provider".to_string()), Some(StatusCode::BAD_REQUEST), ); return Action::Pause; @@ -603,99 +618,100 @@ impl HttpContext for StreamContext { ); } + let provider_id = self.get_provider_id(); + let supported_api = self.supported_api.as_ref(); + if self.streaming_response { debug!("processing streaming response"); - match ProviderStreamResponseIter::try_from((&body[..], &self.get_provider_id())) { - Ok(mut streaming_response) => { - // Process each streaming chunk - while let Some(chunk_result) = streaming_response.next() { - match chunk_result { - Ok(chunk) => { - // Compute TTFT on first chunk - if self.ttft_duration.is_none() { - let current_time = get_current_time().unwrap(); - self.ttft_time = Some(current_time_ns()); - match current_time.duration_since(self.start_time) { - Ok(duration) => { - let duration_ms = duration.as_millis(); - info!( - "on_http_response_body: time to first token: {}ms", - duration_ms - ); - self.ttft_duration = Some(duration); - self.metrics - .time_to_first_token - .record(duration_ms as u64); + match (supported_api, self.resolved_api.as_ref()) { + (Some(supported_api), Some(_)) => { + match ProviderStreamResponseIter::try_from(( + &body[..], + supported_api, + &provider_id, + )) { + Ok(mut streaming_response) => { + while let Some(chunk_result) = streaming_response.next() { + match chunk_result { + Ok(chunk) => { + if self.ttft_duration.is_none() { + let current_time = get_current_time().unwrap(); + self.ttft_time = Some(current_time_ns()); + match current_time.duration_since(self.start_time) { + Ok(duration) => { + let duration_ms = duration.as_millis(); + info!("on_http_response_body: time to first token: {}ms", duration_ms); + self.ttft_duration = Some(duration); + self.metrics + .time_to_first_token + .record(duration_ms as u64); + } + Err(e) => { + warn!("SystemTime error: {:?}", e); + } + } } - Err(e) => { - warn!("SystemTime error: {:?}", e); + if chunk.is_final() { + debug!("Received final streaming chunk"); + } + if let Some(content) = chunk.content_delta() { + let estimated_tokens = content.len() / 4; + self.response_tokens += estimated_tokens.max(1); } } - } - - // For streaming responses, we handle token counting differently - // The ProviderStreamResponse trait provides content_delta, is_final, and role - // Token counting for streaming responses typically happens with final usage chunk - if chunk.is_final() { - // For now, we'll implement basic token estimation - // In a complete implementation, the final chunk would contain usage information - debug!("Received final streaming chunk"); - } - - // For now, estimate tokens from content delta - if let Some(content) = chunk.content_delta() { - // Rough estimation: ~4 characters per token - let estimated_tokens = content.len() / 4; - self.response_tokens += estimated_tokens.max(1); + Err(e) => { + warn!("Error processing streaming chunk: {}", e); + return Action::Continue; + } } } - Err(e) => { - warn!("Error processing streaming chunk: {}", e); - return Action::Continue; - } + } + Err(e) => { + warn!("Failed to parse streaming response: {}", e); } } } - Err(e) => { - warn!("Failed to parse streaming response: {}", e); + _ => { + warn!("Missing supported_api or resolved_api for streaming response"); } } } else { debug!("non streaming response"); - let provider_id = self.get_provider_id(); - let response: ProviderResponseType = - match ProviderResponseType::try_from((&body[..], provider_id)) { - Ok(response) => response, - Err(e) => { - warn!( - "could not parse response: {}, body str: {}", - e, - String::from_utf8_lossy(&body) - ); - debug!( - "on_http_response_body: S[{}], response body: {}", - self.context_id, - String::from_utf8_lossy(&body) - ); - self.send_server_error( - ServerError::LogicError(format!("Response parsing error: {}", e)), - Some(StatusCode::BAD_REQUEST), - ); - return Action::Continue; + match (supported_api, self.resolved_api.as_ref()) { + (Some(supported_api), Some(_)) => { + match ProviderResponseType::try_from((&body[..], supported_api, &provider_id)) { + Ok(response) => match serde_json::to_vec(&response) { + Ok(bytes) => { + self.set_http_response_body(0, bytes.len(), &bytes); + } + Err(e) => { + self.send_server_error( + ServerError::LogicError(format!( + "Response serialization error: {}", + e + )), + Some(StatusCode::BAD_REQUEST), + ); + return Action::Continue; + } + }, + Err(e) => { + warn!( + "could not parse response: {}, body str: {}", + e, + String::from_utf8_lossy(&body) + ); + self.send_server_error( + ServerError::LogicError(format!("Response parsing error: {}", e)), + Some(StatusCode::BAD_REQUEST), + ); + return Action::Continue; + } } - }; - - // Use provider interface to extract usage information - if let Some((prompt_tokens, completion_tokens, total_tokens)) = - response.extract_usage_counts() - { - debug!( - "Response usage: prompt={}, completion={}, total={}", - prompt_tokens, completion_tokens, total_tokens - ); - self.response_tokens = completion_tokens; - } else { - warn!("No usage information found in response"); + } + _ => { + warn!("Missing supported_api or resolved_api for non-streaming response"); + } } }