From 0aa924309326e574bbca8acc23cb1299cf204f65 Mon Sep 17 00:00:00 2001 From: Salman Paracha Date: Thu, 21 Aug 2025 22:24:07 -0700 Subject: [PATCH] pushing draft PR --- crates/common/src/consts.rs | 1 + crates/hermesllm/src/apis/anthropic.rs | 83 +++++++++++++++ crates/hermesllm/src/clients/endpoints.rs | 111 +++++++++++++++++---- crates/hermesllm/src/clients/mod.rs | 2 +- crates/hermesllm/src/providers/adapters.rs | 22 +++- crates/hermesllm/src/providers/request.rs | 47 ++++++++- crates/hermesllm/src/providers/response.rs | 14 +-- crates/llm_gateway/src/stream_context.rs | 95 +++++++++++++----- 8 files changed, 319 insertions(+), 56 deletions(-) diff --git a/crates/common/src/consts.rs b/crates/common/src/consts.rs index 3ff2ce5e..0eb5a036 100644 --- a/crates/common/src/consts.rs +++ b/crates/common/src/consts.rs @@ -12,6 +12,7 @@ pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider"; pub const MESSAGES_KEY: &str = "messages"; pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint"; pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions"; +pub const MESSAGES_PATH: &str = "/v1/messages"; pub const HEALTHZ_PATH: &str = "/healthz"; pub const X_ARCH_STATE_HEADER: &str = "x-arch-state"; pub const X_ARCH_API_RESPONSE: &str = "x-arch-api-response-message"; diff --git a/crates/hermesllm/src/apis/anthropic.rs b/crates/hermesllm/src/apis/anthropic.rs index 0ffe4e8d..96d4eaa9 100644 --- a/crates/hermesllm/src/apis/anthropic.rs +++ b/crates/hermesllm/src/apis/anthropic.rs @@ -4,6 +4,7 @@ use serde_with::skip_serializing_none; use std::collections::HashMap; use super::ApiDefinition; +use crate::providers::request::{ProviderRequest, ProviderRequestError}; // Enum for all supported Anthropic APIs #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] @@ -369,6 +370,88 @@ impl MessagesRequest { } } +impl TryFrom<&[u8]> for MessagesRequest { + type Error = serde_json::Error; + + fn try_from(bytes: &[u8]) -> Result { + serde_json::from_slice(bytes) + } +} + +impl ProviderRequest for MessagesRequest { + fn model(&self) -> &str { + &self.model + } + + fn set_model(&mut self, model: String) { + self.model = model; + } + + fn is_streaming(&self) -> bool { + self.stream.unwrap_or(false) + } + + fn extract_messages_text(&self) -> String { + let mut text_parts = Vec::new(); + + // Include system prompt if present + if let Some(system) = &self.system { + match system { + MessagesSystemPrompt::Single(s) => text_parts.push(s.clone()), + MessagesSystemPrompt::Blocks(blocks) => { + for block in blocks { + if let MessagesContentBlock::Text { text } = block { + text_parts.push(text.clone()); + } + } + } + } + } + + // Extract text from all messages + for message in &self.messages { + match &message.content { + MessagesMessageContent::Single(text) => text_parts.push(text.clone()), + MessagesMessageContent::Blocks(blocks) => { + for block in blocks { + if let MessagesContentBlock::Text { text } = block { + text_parts.push(text.clone()); + } + } + } + } + } + + text_parts.join(" ") + } + + fn get_recent_user_message(&self) -> Option { + // Find the most recent user message + for message in self.messages.iter().rev() { + if message.role == MessagesRole::User { + match &message.content { + MessagesMessageContent::Single(text) => return Some(text.clone()), + MessagesMessageContent::Blocks(blocks) => { + for block in blocks { + if let MessagesContentBlock::Text { text } = block { + return Some(text.clone()); + } + } + } + } + } + } + None + } + + fn to_bytes(&self) -> Result, ProviderRequestError> { + serde_json::to_vec(self).map_err(|e| ProviderRequestError { + message: format!("Failed to serialize MessagesRequest: {}", e), + source: Some(Box::new(e)), + }) + } +} + impl MessagesResponse { pub fn api_type() -> AnthropicApi { AnthropicApi::Messages diff --git a/crates/hermesllm/src/clients/endpoints.rs b/crates/hermesllm/src/clients/endpoints.rs index bf0648a9..a834969b 100644 --- a/crates/hermesllm/src/clients/endpoints.rs +++ b/crates/hermesllm/src/clients/endpoints.rs @@ -6,12 +6,13 @@ //! # Examples //! //! ```rust -//! use hermesllm::clients::endpoints::{is_supported_endpoint, supported_endpoints}; +//! use hermesllm::clients::endpoints::supported_endpoints; //! //! // Check if we support an endpoint -//! assert!(is_supported_endpoint("/v1/chat/completions")); -//! assert!(is_supported_endpoint("/v1/messages")); -//! assert!(!is_supported_endpoint("/v1/unknown")); +//! use hermesllm::clients::endpoints::SupportedApi; +//! assert!(SupportedApi::from_endpoint("/v1/chat/completions").is_some()); +//! assert!(SupportedApi::from_endpoint("/v1/messages").is_some()); +//! assert!(!SupportedApi::from_endpoint("/v1/unknown").is_some()); //! //! // Get all supported endpoints //! let endpoints = supported_endpoints(); @@ -22,21 +23,87 @@ use crate::apis::{AnthropicApi, OpenAIApi, ApiDefinition}; -/// Check if the given endpoint path is supported -pub fn is_supported_endpoint(endpoint: &str) -> bool { - // Try OpenAI APIs - if OpenAIApi::from_endpoint(endpoint).is_some() { - return true; - } - - // Try Anthropic APIs - if AnthropicApi::from_endpoint(endpoint).is_some() { - return true; - } - - false +/// Unified enum representing all supported API endpoints across providers +#[derive(Debug, Clone, PartialEq)] +pub enum SupportedApi { + OpenAI(OpenAIApi), + Anthropic(AnthropicApi), } +impl SupportedApi { + /// Determine if a request/response conversion is required for the given model string + pub fn requires_conversion_for_model(&self, model: &str) -> bool { + use crate::providers::adapters::is_claude_family; + match self { + SupportedApi::Anthropic(AnthropicApi::Messages) => !is_claude_family(model), + SupportedApi::OpenAI(OpenAIApi::ChatCompletions) => is_claude_family(model), + } + } + /// Create a SupportedApi from an endpoint path + pub fn from_endpoint(endpoint: &str) -> Option { + if let Some(openai_api) = OpenAIApi::from_endpoint(endpoint) { + return Some(SupportedApi::OpenAI(openai_api)); + } + + if let Some(anthropic_api) = AnthropicApi::from_endpoint(endpoint) { + return Some(SupportedApi::Anthropic(anthropic_api)); + } + + None + } + + /// Get the endpoint path for this API + pub fn endpoint(&self) -> &'static str { + match self { + SupportedApi::OpenAI(api) => api.endpoint(), + SupportedApi::Anthropic(api) => api.endpoint(), + } + } + + /// Get the API family name + pub fn api_family(&self) -> &'static str { + match self { + SupportedApi::OpenAI(_) => "openai", + SupportedApi::Anthropic(_) => "anthropic", + } + } + + /// Determine the target endpoint for a given provider + /// For /v1/messages: if provider is Anthropic, use /v1/messages; otherwise use /v1/chat/completions + pub fn target_endpoint_for_provider(&self, provider: &str) -> &'static str { + match self { + SupportedApi::Anthropic(AnthropicApi::Messages) => { + if provider.to_lowercase().contains("anthropic") || + provider.to_lowercase().contains("claude") { + "/v1/messages" + } else { + "/v1/chat/completions" + } + } + _ => self.endpoint() + } + } + + /// Check if request conversion is required for the given provider + /// True if we need to convert between Anthropic and OpenAI formats + pub fn requires_conversion(&self, provider: &str) -> bool { + match self { + SupportedApi::Anthropic(AnthropicApi::Messages) => { + // If provider is not Anthropic/Claude, we need to convert to OpenAI format + !(provider.to_lowercase().contains("anthropic") || + provider.to_lowercase().contains("claude")) + } + SupportedApi::OpenAI(OpenAIApi::ChatCompletions) => { + // If provider is Anthropic/Claude but request is OpenAI format, need conversion + provider.to_lowercase().contains("anthropic") || + provider.to_lowercase().contains("claude") + } + } + } +} + + + /// Get all supported endpoint paths pub fn supported_endpoints() -> Vec<&'static str> { let mut endpoints = Vec::new(); @@ -74,15 +141,15 @@ mod tests { #[test] fn test_is_supported_endpoint() { // OpenAI endpoints - assert!(is_supported_endpoint("/v1/chat/completions")); + assert!(SupportedApi::from_endpoint("/v1/chat/completions").is_some()); // Anthropic endpoints - assert!(is_supported_endpoint("/v1/messages")); + assert!(SupportedApi::from_endpoint("/v1/messages").is_some()); // Unsupported endpoints - assert!(!is_supported_endpoint("/v1/unknown")); - assert!(!is_supported_endpoint("/v2/chat")); - assert!(!is_supported_endpoint("")); + assert!(!SupportedApi::from_endpoint("/v1/unknown").is_some()); + assert!(!SupportedApi::from_endpoint("/v2/chat").is_some()); + assert!(!SupportedApi::from_endpoint("").is_some()); } #[test] diff --git a/crates/hermesllm/src/clients/mod.rs b/crates/hermesllm/src/clients/mod.rs index eb3032ce..91522198 100644 --- a/crates/hermesllm/src/clients/mod.rs +++ b/crates/hermesllm/src/clients/mod.rs @@ -4,6 +4,6 @@ pub mod endpoints; // Re-export the main items for easier access pub use lib::*; -pub use endpoints::{is_supported_endpoint, supported_endpoints, identify_provider}; +pub use endpoints::{SupportedApi, identify_provider}; // Note: transformer module contains TryFrom trait implementations that are automatically available diff --git a/crates/hermesllm/src/providers/adapters.rs b/crates/hermesllm/src/providers/adapters.rs index a001cf09..17b57dc5 100644 --- a/crates/hermesllm/src/providers/adapters.rs +++ b/crates/hermesllm/src/providers/adapters.rs @@ -1,9 +1,21 @@ +//! Provider adapter configuration and API compatibility utilities. +// +// Note: For all request/response conversions between Anthropic and OpenAI APIs, +// use the peer-reviewed and well-tested implementations in `clients/transformer.rs`. +// This file should not contain conversion logic. + +/// Utility to check if a model is from the Claude/Anthropic family +pub fn is_claude_family(model: &str) -> bool { + let model = model.to_lowercase(); + model.contains("claude") || model.contains("anthropic") +} use crate::providers::id::ProviderId; #[derive(Debug, Clone)] pub enum AdapterType { OpenAICompatible, - // Future: Claude, Gemini, etc. + AnthropicCompatible, + // Future: Gemini, etc. } /// Provider adapter configuration @@ -29,11 +41,17 @@ pub fn supported_apis(provider_id: &ProviderId) -> Vec<&'static str> { pub fn get_provider_config(provider_id: &ProviderId) -> ProviderConfig { match provider_id { ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek - | ProviderId::Arch | ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => { + | ProviderId::Arch | ProviderId::Gemini | ProviderId::GitHub => { ProviderConfig { supported_apis: &["/v1/chat/completions"], adapter_type: AdapterType::OpenAICompatible, } } + ProviderId::Claude => { + ProviderConfig { + supported_apis: &["/v1/messages", "/v1/chat/completions"], + adapter_type: AdapterType::AnthropicCompatible, + } + } } } diff --git a/crates/hermesllm/src/providers/request.rs b/crates/hermesllm/src/providers/request.rs index 1eb39416..e0d82566 100644 --- a/crates/hermesllm/src/providers/request.rs +++ b/crates/hermesllm/src/providers/request.rs @@ -1,11 +1,13 @@ use crate::apis::openai::ChatCompletionsRequest; +use crate::apis::anthropic::MessagesRequest; +use crate::clients::endpoints::SupportedApi; use super::{ProviderId, get_provider_config, AdapterType}; use std::error::Error; use std::fmt; pub enum ProviderRequestType { ChatCompletionsRequest(ChatCompletionsRequest), - //MessagesRequest(MessagesRequest), + MessagesRequest(MessagesRequest), //add more request types here } @@ -31,7 +33,42 @@ impl TryFrom<(&[u8], &ProviderId)> for ProviderRequestType { .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request)) } - // Future: handle other adapter types like Claude + AdapterType::AnthropicCompatible => { + // For Anthropic providers, try to parse as MessagesRequest first, fallback to ChatCompletionsRequest + if let Ok(messages_request) = MessagesRequest::try_from(bytes) { + Ok(ProviderRequestType::MessagesRequest(messages_request)) + } else { + let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request)) + } + } + } + } +} + +/// Parse request based on endpoint and provider information +impl TryFrom<(&[u8], &str, &ProviderId)> for ProviderRequestType { + type Error = std::io::Error; + + fn try_from((bytes, endpoint, provider_id): (&[u8], &str, &ProviderId)) -> Result { + // Use SupportedApi to determine the appropriate request type + if let Some(api) = SupportedApi::from_endpoint(endpoint) { + match api { + SupportedApi::OpenAI(_) => { + let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request)) + } + SupportedApi::Anthropic(_) => { + let messages_request: MessagesRequest = MessagesRequest::try_from(bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + Ok(ProviderRequestType::MessagesRequest(messages_request)) + } + } + } else { + // Fallback to provider-based parsing for unsupported endpoints + Self::try_from((bytes, provider_id)) } } } @@ -60,36 +97,42 @@ impl ProviderRequest for ProviderRequestType { fn model(&self) -> &str { match self { Self::ChatCompletionsRequest(r) => r.model(), + Self::MessagesRequest(r) => r.model(), } } fn set_model(&mut self, model: String) { match self { Self::ChatCompletionsRequest(r) => r.set_model(model), + Self::MessagesRequest(r) => r.set_model(model), } } fn is_streaming(&self) -> bool { match self { Self::ChatCompletionsRequest(r) => r.is_streaming(), + Self::MessagesRequest(r) => r.is_streaming(), } } fn extract_messages_text(&self) -> String { match self { Self::ChatCompletionsRequest(r) => r.extract_messages_text(), + Self::MessagesRequest(r) => r.extract_messages_text(), } } fn get_recent_user_message(&self) -> Option { match self { Self::ChatCompletionsRequest(r) => r.get_recent_user_message(), + Self::MessagesRequest(r) => r.get_recent_user_message(), } } fn to_bytes(&self) -> Result, ProviderRequestError> { match self { Self::ChatCompletionsRequest(r) => r.to_bytes(), + Self::MessagesRequest(r) => r.to_bytes(), } } } diff --git a/crates/hermesllm/src/providers/response.rs b/crates/hermesllm/src/providers/response.rs index faca303f..56410540 100644 --- a/crates/hermesllm/src/providers/response.rs +++ b/crates/hermesllm/src/providers/response.rs @@ -27,7 +27,10 @@ impl TryFrom<(&[u8], ProviderId)> for ProviderResponseType { .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; Ok(ProviderResponseType::ChatCompletionsResponse(chat_completions_response)) } - // Future: handle other adapter types like Claude + AdapterType::AnthropicCompatible => { + // TODO: Implement MessagesResponse parsing for Anthropic-compatible providers + todo!("AnthropicCompatible response parsing not yet implemented"); + } } } } @@ -49,11 +52,10 @@ impl TryFrom<(&[u8], &ProviderId)> for ProviderStreamResponseIter { let iter = crate::apis::openai::OpenAISseIter::new(sse_container); Ok(ProviderStreamResponseIter::ChatCompletionsStream(iter)) } - // Future: AdapterType::Claude => { - // let sse_container = SseStreamIter::new(lines.into_iter()); - // let iter = crate::apis::anthropic::AnthropicSseIter::new(sse_container); - // Ok(ProviderStreamResponseIter::MessagesStream(iter)) - // } + AdapterType::AnthropicCompatible => { + // TODO: Implement Anthropic streaming support + todo!("AnthropicCompatible streaming not yet implemented"); + } } } } diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 6b2c5f15..cc769c7a 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -1,8 +1,8 @@ use crate::metrics::Metrics; use common::configuration::{LlmProvider, LlmProviderType, Overrides}; use common::consts::{ - ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, CHAT_COMPLETIONS_PATH, HEALTHZ_PATH, - RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, TRACE_PARENT_HEADER, + ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, HEALTHZ_PATH, RATELIMIT_SELECTOR_HEADER_KEY, + REQUEST_ID_HEADER, TRACE_PARENT_HEADER, }; use common::errors::ServerError; use common::llm_providers::LlmProviders; @@ -10,6 +10,7 @@ use common::ratelimit::Header; use common::stats::{IncrementingMetric, RecordingMetric}; use common::tracing::{Event, Span, TraceData, Traceparent}; use common::{ratelimit, routing, tokenizer}; +use hermesllm::clients::endpoints::SupportedApi; use hermesllm::providers::response::ProviderStreamResponseIter; use hermesllm::{ ProviderId, ProviderRequest, ProviderRequestType, ProviderResponse, ProviderResponseType, @@ -31,7 +32,7 @@ pub struct StreamContext { ratelimit_selector: Option
, streaming_response: bool, response_tokens: usize, - is_chat_completions_request: bool, + supported_api: Option, llm_providers: Rc, llm_provider: Option>, request_id: Option, @@ -60,7 +61,7 @@ impl StreamContext { ratelimit_selector: None, streaming_response: false, response_tokens: 0, - is_chat_completions_request: false, + supported_api: None, llm_providers, llm_provider: None, request_id: None, @@ -212,7 +213,15 @@ impl HttpContext for StreamContext { return Action::Continue; } - self.is_chat_completions_request = CHAT_COMPLETIONS_PATH == request_path; + // Check if this is a supported API endpoint + if SupportedApi::from_endpoint(&request_path).is_none() { + self.send_http_response(404, vec![], Some(b"Unsupported endpoint")); + return Action::Continue; + } + + // Get the SupportedApi for routing decisions + let supported_api = SupportedApi::from_endpoint(&request_path); + self.supported_api = supported_api; let use_agent_orchestrator = match self.overrides.as_ref() { Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(), @@ -257,6 +266,39 @@ impl HttpContext for StreamContext { self.delete_content_length_header(); self.save_ratelimit_header(); + // Apply provider-specific path routing + match self.llm_provider.as_ref().unwrap().provider_interface { + LlmProviderType::Groq => { + if let Some(path) = self.get_http_request_header(":path") { + if path.starts_with("/v1/") { + let new_path = format!("/openai{}", path); + self.set_http_request_header(":path", Some(new_path.as_str())); + } + } + } + LlmProviderType::Gemini => { + if let Some(path) = self.get_http_request_header(":path") { + if path == "/v1/chat/completions" { + self.set_http_request_header( + ":path", + Some("/v1beta/openai/chat/completions"), + ); + } + } + } + _ => { + // Use SupportedApi for endpoint routing + if let Some(api) = &self.supported_api { + let provider_name = &self.llm_provider.as_ref().unwrap().name; + let target_endpoint = api.target_endpoint_for_provider(provider_name); + // Only update path if it's different from the original + if target_endpoint != request_path { + self.set_http_request_header(":path", Some(target_endpoint)); + } + } + } + } + self.request_id = self.get_http_request_header(REQUEST_ID_HEADER); self.traceparent = self.get_http_request_header(TRACE_PARENT_HEADER); @@ -299,22 +341,26 @@ impl HttpContext for StreamContext { }; let provider_id = self.get_provider_id(); + let request_path = self.get_http_request_header(":path").unwrap_or_default(); - let mut deserialized_body = - match ProviderRequestType::try_from((&body_bytes[..], &provider_id)) { - Ok(deserialized) => deserialized, - Err(e) => { - debug!( - "on_http_request_body: request body: {}", - String::from_utf8_lossy(&body_bytes) - ); - self.send_server_error( - ServerError::LogicError(format!("Request parsing error: {}", e)), - Some(StatusCode::BAD_REQUEST), - ); - return Action::Pause; - } - }; + let mut deserialized_body = match ProviderRequestType::try_from(( + &body_bytes[..], + request_path.as_str(), + &provider_id, + )) { + Ok(deserialized) => deserialized, + Err(e) => { + debug!( + "on_http_request_body: request body: {}", + String::from_utf8_lossy(&body_bytes) + ); + self.send_server_error( + ServerError::LogicError(format!("Request parsing error: {}", e)), + Some(StatusCode::BAD_REQUEST), + ); + return Action::Pause; + } + }; let model_name = match self.llm_provider.as_ref() { Some(llm_provider) => llm_provider.model.as_ref(), @@ -423,9 +469,12 @@ impl HttpContext for StreamContext { return Action::Continue; } - if !self.is_chat_completions_request { - info!("on_http_response_body: non-chatcompletion request"); - return Action::Continue; + match self.supported_api { + Some(SupportedApi::OpenAI(_)) => {} + _ => { + info!("on_http_response_body: non-chatcompletion request"); + return Action::Continue; + } } let current_time = get_current_time().unwrap();