diff --git a/crates/hermesllm/src/clients/endpoints.rs b/crates/hermesllm/src/clients/endpoints.rs index 7b6bb3fe..ca751fbd 100644 --- a/crates/hermesllm/src/clients/endpoints.rs +++ b/crates/hermesllm/src/clients/endpoints.rs @@ -21,7 +21,7 @@ //! assert!(endpoints.contains(&"/v1/messages")); //! ``` -use crate::apis::{AnthropicApi, OpenAIApi, ApiDefinition}; +use crate::{apis::{AnthropicApi, ApiDefinition, OpenAIApi}, ProviderId}; /// Unified enum representing all supported API endpoints across providers #[derive(Debug, Clone, PartialEq)] @@ -52,18 +52,34 @@ impl SupportedAPIs { } } - //TODO: we need to clean this up. Why do we need this in the first place? - pub fn target_endpoint_for_provider(&self, provider: &str) -> &'static str { + pub fn target_endpoint_for_provider(&self, provider_id: &ProviderId, request_path: &str) -> String { + let default_endpoint = "/v1/chat/completions".to_string(); match self { SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) => { - if provider.to_lowercase().contains("anthropic") || - provider.to_lowercase().contains("claude") { - "/v1/messages" - } else { - "/v1/chat/completions" + match provider_id { + ProviderId::Claude => "/v1/messages".to_string(), + _ => default_endpoint, + } + } + _ => { + match provider_id { + ProviderId::Groq => { + if request_path.starts_with("/v1/") { + format!("/openai{}", request_path) + } else { + default_endpoint + } + } + ProviderId::Gemini => { + if request_path.starts_with("/v1/") { + "/v1beta/openai/chat/completions".to_string() + } else { + default_endpoint + } + } + _ => default_endpoint, } } - _ => self.endpoint() } } } diff --git a/crates/hermesllm/src/providers/response.rs b/crates/hermesllm/src/providers/response.rs index b1d387e3..610d16eb 100644 --- a/crates/hermesllm/src/providers/response.rs +++ b/crates/hermesllm/src/providers/response.rs @@ -12,6 +12,7 @@ use std::convert::TryFrom; use crate::apis::anthropic::MessagesResponse; #[derive(Serialize)] +#[serde(untagged)] pub enum ProviderResponseType { ChatCompletionsResponse(ChatCompletionsResponse), MessagesResponse(MessagesResponse), @@ -104,19 +105,6 @@ impl Iterator for ProviderStreamResponseIter { } } } - -// Helper to serialize only the inner struct, not the enum wrapper. -// This avoids the problem where serde serializes the enum variant as a wrapper object in JSON. -impl ProviderResponseType { - /// Serialize the response as JSON bytes, omitting the enum wrapper. - pub fn as_json_bytes(&self) -> Result, serde_json::Error> { - match self { - ProviderResponseType::ChatCompletionsResponse(resp) => serde_json::to_vec(resp), - ProviderResponseType::MessagesResponse(resp) => serde_json::to_vec(resp), - } - } -} - pub trait ProviderResponse: Send + Sync { /// Get usage information if available - returns dynamic trait object fn usage(&self) -> Option<&dyn TokenUsage>; diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index b41396fd..fbd7faf5 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -11,7 +11,7 @@ use common::stats::{IncrementingMetric, RecordingMetric}; use common::tracing::{Event, Span, TraceData, Traceparent}; use common::{ratelimit, routing, tokenizer}; use hermesllm::clients::endpoints::SupportedAPIs; -use hermesllm::providers::response::ProviderStreamResponseIter; +use hermesllm::providers::response::{ProviderResponse, ProviderStreamResponseIter}; use hermesllm::{ProviderId, ProviderRequest, ProviderRequestType, ProviderResponseType}; use http::StatusCode; use log::{debug, info, warn}; @@ -85,6 +85,18 @@ impl StreamContext { self.llm_provider().to_provider_id() } + //This function assumes that the provider has been set. + fn update_upstream_path(&mut self, request_path: &str) { + let hermes_provider_id = self.llm_provider().to_provider_id(); + if let Some(api) = &self.client_api { + let target_endpoint = + api.target_endpoint_for_provider(&hermes_provider_id, request_path); + if target_endpoint != request_path { + self.set_http_request_header(":path", Some(&target_endpoint)); + } + } + } + fn select_llm_provider(&mut self) { let provider_hint = self .get_http_request_header(ARCH_PROVIDER_HINT_HEADER) @@ -95,28 +107,6 @@ impl StreamContext { provider_hint, )); - match self.llm_provider.as_ref().unwrap().provider_interface { - LlmProviderType::Groq => { - if let Some(path) = self.get_http_request_header(":path") { - if path.starts_with("/v1/") { - let new_path = format!("/openai{}", path); - self.set_http_request_header(":path", Some(new_path.as_str())); - } - } - } - LlmProviderType::Gemini => { - if let Some(path) = self.get_http_request_header(":path") { - if path == "/v1/chat/completions" { - self.set_http_request_header( - ":path", - Some("/v1beta/openai/chat/completions"), - ); - } - } - } - _ => {} - } - debug!( "request received: llm provider hint: {}, selected provider: {}", self.get_http_request_header(ARCH_PROVIDER_HINT_HEADER) @@ -227,10 +217,35 @@ impl HttpContext for StreamContext { self.llm_provider = Some(Rc::new(LlmProvider { name: routing_header_value.to_string(), provider_interface: LlmProviderType::OpenAI, - ..Default::default() + ..Default::default() //TODO: THiS IS BROKEN. WHY ARE WE ASSUMING OPENAI FOR UPSTREAM? })); } else { + //TODO: Fix this brittle code path. We need to return values and have compile time self.select_llm_provider(); + + // Check if this is a supported API endpoint + if SupportedAPIs::from_endpoint(&request_path).is_none() { + self.send_http_response(404, vec![], Some(b"Unsupported endpoint")); + return Action::Continue; + } + + // Get the SupportedApi for routing decisions + let supported_api: Option = SupportedAPIs::from_endpoint(&request_path); + self.client_api = supported_api; + + // Debug: log provider, client API, resolved API, and request path + if let (Some(api), Some(provider)) = + (self.client_api.as_ref(), self.llm_provider.as_ref()) + { + let provider_id = provider.to_provider_id(); + self.resolved_api = Some(provider_id.compatible_api_for_client(api)); + } else { + self.resolved_api = None; + } + + //We need to update the upstream path if there is a variation for a provider like Gemini/Groq, etc. + self.update_upstream_path(&request_path); + if self.llm_provider().endpoint.is_some() { self.add_http_request_header( ARCH_ROUTING_HEADER, @@ -257,62 +272,9 @@ impl HttpContext for StreamContext { self.delete_content_length_header(); self.save_ratelimit_header(); - // Apply provider-specific path routing - match self.llm_provider.as_ref().unwrap().provider_interface { - LlmProviderType::Groq => { - if let Some(path) = self.get_http_request_header(":path") { - if path.starts_with("/v1/") { - let new_path = format!("/openai{}", path); - self.set_http_request_header(":path", Some(new_path.as_str())); - } - } - } - LlmProviderType::Gemini => { - if let Some(path) = self.get_http_request_header(":path") { - if path == "/v1/chat/completions" { - self.set_http_request_header( - ":path", - Some("/v1beta/openai/chat/completions"), - ); - } - } - } - _ => { - // Use SupportedApi for endpoint routing - if let Some(api) = &self.client_api { - let provider_name = &self.llm_provider.as_ref().unwrap().name; - let target_endpoint = api.target_endpoint_for_provider(provider_name); - // Only update path if it's different from the original - if target_endpoint != request_path { - self.set_http_request_header(":path", Some(target_endpoint)); - } - } - } - } - self.request_id = self.get_http_request_header(REQUEST_ID_HEADER); self.traceparent = self.get_http_request_header(TRACE_PARENT_HEADER); - // Check if this is a supported API endpoint - if SupportedAPIs::from_endpoint(&request_path).is_none() { - self.send_http_response(404, vec![], Some(b"Unsupported endpoint")); - return Action::Continue; - } - - // Get the SupportedApi for routing decisions - let supported_api: Option = SupportedAPIs::from_endpoint(&request_path); - self.client_api = supported_api; - - // Debug: log provider, client API, resolved API, and request path - if let (Some(api), Some(provider)) = (self.client_api.as_ref(), self.llm_provider.as_ref()) - { - let provider_id = provider.to_provider_id(); - let resolved_api = provider_id.compatible_api_for_client(api); - self.resolved_api = Some(resolved_api); - } else { - self.resolved_api = None; - } - Action::Continue } @@ -678,30 +640,24 @@ impl HttpContext for StreamContext { } } else { debug!("non streaming response"); - match (supported_api, self.resolved_api.as_ref()) { + let provider_id = self.get_provider_id(); + let supported_api = self.client_api.as_ref(); + + let response: ProviderResponseType = match (supported_api, self.resolved_api.as_ref()) { (Some(supported_api), Some(_)) => { match ProviderResponseType::try_from((&body[..], supported_api, &provider_id)) { - Ok(response) => match response.as_json_bytes() { - Ok(bytes) => { - self.set_http_response_body(0, bytes.len(), &bytes); - } - Err(e) => { - self.send_server_error( - ServerError::LogicError(format!( - "Response serialization error: {}", - e - )), - Some(StatusCode::BAD_REQUEST), - ); - return Action::Continue; - } - }, + Ok(response) => response, Err(e) => { warn!( "could not parse response: {}, body str: {}", e, String::from_utf8_lossy(&body) ); + debug!( + "on_http_response_body: S[{}], response body: {}", + self.context_id, + String::from_utf8_lossy(&body) + ); self.send_server_error( ServerError::LogicError(format!("Response parsing error: {}", e)), Some(StatusCode::BAD_REQUEST), @@ -712,7 +668,21 @@ impl HttpContext for StreamContext { } _ => { warn!("Missing supported_api or resolved_api for non-streaming response"); + return Action::Continue; } + }; + + // Use provider interface to extract usage information + if let Some((prompt_tokens, completion_tokens, total_tokens)) = + response.extract_usage_counts() + { + debug!( + "Response usage: prompt={}, completion={}, total={}", + prompt_tokens, completion_tokens, total_tokens + ); + self.response_tokens = completion_tokens; + } else { + warn!("No usage information found in response"); } } diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index 768e6780..67eca202 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -665,9 +665,6 @@ impl StreamContext { } pub fn default_target_handler(&self, body: Vec, mut callout_context: StreamCallContext) { - // Debug: print raw bytes in hex to diagnose extra data - debug!("raw upstream response bytes (hex): {}", - body.iter().map(|b| format!("{:02x}", b)).collect::>().join(" ")); let prompt_target = self .prompt_targets .get(callout_context.prompt_target_name.as_ref().unwrap())