fixed bugs for integration tests

This commit is contained in:
Salman Paracha 2025-08-23 16:37:52 -07:00
parent 9f3a6f71a3
commit 7345657612
8 changed files with 82 additions and 80 deletions

View file

@ -33,6 +33,8 @@ pub fn get_llm_provider(
return provider;
}
//This is a fallback to the default provider if no specific provider is found.
//For example, if the client sends in gpt-4-1 and that's not configured in arch_config, we fall back to the default.
if llm_providers.default().is_some() {
return llm_providers.default().unwrap();
}

View file

@ -25,20 +25,20 @@ use crate::apis::{AnthropicApi, OpenAIApi, ApiDefinition};
/// Unified enum representing all supported API endpoints across providers
#[derive(Debug, Clone, PartialEq)]
pub enum SupportedApi {
OpenAI(OpenAIApi),
Anthropic(AnthropicApi),
pub enum SupportedAPIs {
OpenAIChatCompletions(OpenAIApi),
AnthropicMessagesAPI(AnthropicApi),
}
impl SupportedApi {
impl SupportedAPIs {
/// Create a SupportedApi from an endpoint path
pub fn from_endpoint(endpoint: &str) -> Option<Self> {
if let Some(openai_api) = OpenAIApi::from_endpoint(endpoint) {
return Some(SupportedApi::OpenAI(openai_api));
return Some(SupportedAPIs::OpenAIChatCompletions(openai_api));
}
if let Some(anthropic_api) = AnthropicApi::from_endpoint(endpoint) {
return Some(SupportedApi::Anthropic(anthropic_api));
return Some(SupportedAPIs::AnthropicMessagesAPI(anthropic_api));
}
None
@ -47,16 +47,15 @@ impl SupportedApi {
/// Get the endpoint path for this API
pub fn endpoint(&self) -> &'static str {
match self {
SupportedApi::OpenAI(api) => api.endpoint(),
SupportedApi::Anthropic(api) => api.endpoint(),
SupportedAPIs::OpenAIChatCompletions(api) => api.endpoint(),
SupportedAPIs::AnthropicMessagesAPI(api) => api.endpoint(),
}
}
/// Determine the target endpoint for a given provider
/// For /v1/messages: if provider is Anthropic, use /v1/messages; otherwise use /v1/chat/completions
//TODO: we need to clean this up. Why do we need this in the first place?
pub fn target_endpoint_for_provider(&self, provider: &str) -> &'static str {
match self {
SupportedApi::Anthropic(AnthropicApi::Messages) => {
SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) => {
if provider.to_lowercase().contains("anthropic") ||
provider.to_lowercase().contains("claude") {
"/v1/messages"
@ -108,15 +107,15 @@ mod tests {
#[test]
fn test_is_supported_endpoint() {
// OpenAI endpoints
assert!(SupportedApi::from_endpoint("/v1/chat/completions").is_some());
assert!(SupportedAPIs::from_endpoint("/v1/chat/completions").is_some());
// Anthropic endpoints
assert!(SupportedApi::from_endpoint("/v1/messages").is_some());
assert!(SupportedAPIs::from_endpoint("/v1/messages").is_some());
// Unsupported endpoints
assert!(!SupportedApi::from_endpoint("/v1/unknown").is_some());
assert!(!SupportedApi::from_endpoint("/v2/chat").is_some());
assert!(!SupportedApi::from_endpoint("").is_some());
assert!(!SupportedAPIs::from_endpoint("/v1/unknown").is_some());
assert!(!SupportedAPIs::from_endpoint("/v2/chat").is_some());
assert!(!SupportedAPIs::from_endpoint("").is_some());
}
#[test]

View file

@ -4,6 +4,6 @@ pub mod endpoints;
// Re-export the main items for easier access
pub use lib::*;
pub use endpoints::{SupportedApi, identify_provider};
pub use endpoints::{SupportedAPIs, identify_provider};
// Note: transformer module contains TryFrom trait implementations that are automatically available

View file

@ -71,8 +71,8 @@ mod tests {
data: [DONE]
"#;
use crate::clients::endpoints::SupportedApi;
let api = SupportedApi::OpenAI(crate::apis::OpenAIApi::ChatCompletions);
use crate::clients::endpoints::SupportedAPIs;
let api = SupportedAPIs::OpenAIChatCompletions(crate::apis::OpenAIApi::ChatCompletions);
let result = ProviderStreamResponseIter::try_from((sse_data.as_bytes(), &api, &ProviderId::OpenAI));
assert!(result.is_ok());

View file

@ -1,5 +1,5 @@
use std::fmt::Display;
use crate::clients::endpoints::SupportedApi;
use crate::clients::endpoints::SupportedAPIs;
use crate::apis::{OpenAIApi, AnthropicApi};
/// Provider identifier enum - simple enum for identifying providers
@ -33,16 +33,16 @@ impl From<&str> for ProviderId {
impl ProviderId {
/// Given a client API, return the compatible upstream API for this provider
pub fn compatible_api_for_client(&self, client_api: &SupportedApi) -> SupportedApi {
pub fn compatible_api_for_client(&self, client_api: &SupportedAPIs) -> SupportedAPIs {
match (self, client_api) {
// Claude/Anthropic providers natively support Anthropic APIs
(ProviderId::Claude, SupportedApi::Anthropic(_)) => client_api.clone(),
(ProviderId::Claude, SupportedAPIs::AnthropicMessagesAPI(_)) => client_api.clone(),
// Claude/Anthropic providers can also support OpenAI chat completions by mapping to Anthropic Messages
(ProviderId::Claude, SupportedApi::OpenAI(OpenAIApi::ChatCompletions)) => SupportedApi::Anthropic(AnthropicApi::Messages),
(ProviderId::Claude, SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions)) => SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages),
// OpenAI-compatible providers only support OpenAI chat completions
(ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch | ProviderId::Gemini | ProviderId::GitHub, SupportedApi::Anthropic(_)) => SupportedApi::OpenAI(OpenAIApi::ChatCompletions),
(ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch | ProviderId::Gemini | ProviderId::GitHub, SupportedApi::OpenAI(_)) => SupportedApi::OpenAI(OpenAIApi::ChatCompletions),
(ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch | ProviderId::Gemini | ProviderId::GitHub, SupportedAPIs::AnthropicMessagesAPI(_)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
(ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch | ProviderId::Gemini | ProviderId::GitHub, SupportedAPIs::OpenAIChatCompletions(_)) => SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
}
}
}

View file

@ -1,6 +1,6 @@
use crate::apis::openai::ChatCompletionsRequest;
use crate::apis::anthropic::MessagesRequest;
use crate::clients::endpoints::SupportedApi;
use crate::clients::endpoints::SupportedAPIs;
use std::error::Error;
use std::fmt;
pub enum ProviderRequestType {
@ -21,18 +21,18 @@ impl TryFrom<&[u8]> for ProviderRequestType {
}
/// Parse request based on endpoint and provider information
impl TryFrom<(&[u8], &SupportedApi)> for ProviderRequestType {
impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType {
type Error = std::io::Error;
fn try_from((bytes, endpoint): (&[u8], &SupportedApi)) -> Result<Self, Self::Error> {
fn try_from((bytes, endpoint): (&[u8], &SupportedAPIs)) -> Result<Self, Self::Error> {
// Use SupportedApi to determine the appropriate request type
match endpoint {
SupportedApi::OpenAI(_) => {
SupportedAPIs::OpenAIChatCompletions(_) => {
let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request))
}
SupportedApi::Anthropic(_) => {
SupportedAPIs::AnthropicMessagesAPI(_) => {
let messages_request: MessagesRequest = MessagesRequest::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(ProviderRequestType::MessagesRequest(messages_request))
@ -131,7 +131,7 @@ impl Error for ProviderRequestError {
#[cfg(test)]
mod tests {
use super::*;
use crate::clients::endpoints::SupportedApi;
use crate::clients::endpoints::SupportedAPIs;
use crate::apis::anthropic::AnthropicApi::Messages;
use crate::apis::openai::OpenAIApi::ChatCompletions;
use crate::apis::anthropic::MessagesRequest as AnthropicMessagesRequest;
@ -171,7 +171,7 @@ mod tests {
]
});
let bytes = serde_json::to_vec(&req).unwrap();
let endpoint = SupportedApi::Anthropic(Messages);
let endpoint = SupportedAPIs::AnthropicMessagesAPI(Messages);
let result = ProviderRequestType::try_from((bytes.as_slice(), &endpoint));
assert!(result.is_ok());
match result.unwrap() {
@ -193,7 +193,7 @@ mod tests {
]
});
let bytes = serde_json::to_vec(&req).unwrap();
let endpoint = SupportedApi::OpenAI(ChatCompletions);
let endpoint = SupportedAPIs::OpenAIChatCompletions(ChatCompletions);
let result = ProviderRequestType::try_from((bytes.as_slice(), &endpoint));
assert!(result.is_ok());
match result.unwrap() {
@ -216,7 +216,7 @@ mod tests {
});
let bytes = serde_json::to_vec(&req).unwrap();
// Intentionally use OpenAI endpoint for Anthropic payload
let endpoint = SupportedApi::OpenAI(ChatCompletions);
let endpoint = SupportedAPIs::OpenAIChatCompletions(ChatCompletions);
let result = ProviderRequestType::try_from((bytes.as_slice(), &endpoint));
// Should parse as ChatCompletionsRequest, not error
assert!(result.is_ok());

View file

@ -6,7 +6,7 @@ use std::fmt;
use crate::apis::openai::ChatCompletionsResponse;
use crate::apis::OpenAISseIter;
use crate::clients::endpoints::SupportedApi;
use crate::clients::endpoints::SupportedAPIs;
use std::convert::TryFrom;
use crate::apis::anthropic::MessagesResponse;
@ -26,28 +26,28 @@ pub enum ProviderStreamResponseIter {
// --- Response transformation logic for client API compatibility ---
impl TryFrom<(&[u8], &SupportedApi, &ProviderId)> for ProviderResponseType {
impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType {
type Error = std::io::Error;
fn try_from((bytes, client_api, provider_id): (&[u8], &SupportedApi, &ProviderId)) -> Result<Self, Self::Error> {
fn try_from((bytes, client_api, provider_id): (&[u8], &SupportedAPIs, &ProviderId)) -> Result<Self, Self::Error> {
let upstream_api = provider_id.compatible_api_for_client(client_api);
match (&upstream_api, client_api) {
(SupportedApi::OpenAI(_), SupportedApi::OpenAI(_)) => {
(SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
let resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(ProviderResponseType::ChatCompletionsResponse(resp))
}
(SupportedApi::Anthropic(_), SupportedApi::Anthropic(_)) => {
(SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
let resp: MessagesResponse = serde_json::from_slice(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(ProviderResponseType::MessagesResponse(resp))
}
(SupportedApi::OpenAI(_), SupportedApi::Anthropic(_)) => {
(SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
let resp: MessagesResponse = serde_json::from_slice(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(ProviderResponseType::MessagesResponse(resp))
}
(SupportedApi::Anthropic(_), SupportedApi::OpenAI(_)) => {
(SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
let resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(ProviderResponseType::ChatCompletionsResponse(resp))
@ -56,34 +56,34 @@ impl TryFrom<(&[u8], &SupportedApi, &ProviderId)> for ProviderResponseType {
}
}
impl TryFrom<(&[u8], &SupportedApi, &ProviderId)> for ProviderStreamResponseIter {
impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderStreamResponseIter {
type Error = Box<dyn std::error::Error + Send + Sync>;
fn try_from((bytes, client_api, provider_id): (&[u8], &SupportedApi, &ProviderId)) -> Result<Self, Self::Error> {
fn try_from((bytes, client_api, provider_id): (&[u8], &SupportedAPIs, &ProviderId)) -> Result<Self, Self::Error> {
let upstream_api = provider_id.compatible_api_for_client(client_api);
match (&upstream_api, client_api) {
(SupportedApi::OpenAI(_), SupportedApi::OpenAI(_)) => {
(SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
let s = std::str::from_utf8(bytes)?;
let lines: Vec<String> = s.lines().map(|line| line.to_string()).collect();
let sse_container = crate::providers::response::SseStreamIter::new(lines.into_iter());
let iter = crate::apis::openai::OpenAISseIter::new(sse_container);
Ok(ProviderStreamResponseIter::ChatCompletionsStream(iter))
}
(SupportedApi::Anthropic(_), SupportedApi::Anthropic(_)) => {
(SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
let s = std::str::from_utf8(bytes)?;
let lines: Vec<String> = s.lines().map(|line| line.to_string()).collect();
let sse_container = crate::providers::response::SseStreamIter::new(lines.into_iter());
let iter = crate::apis::anthropic::AnthropicSseIter::new(sse_container);
Ok(ProviderStreamResponseIter::MessagesStream(iter))
}
(SupportedApi::OpenAI(_), SupportedApi::Anthropic(_)) => {
(SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
let s = std::str::from_utf8(bytes)?;
let lines: Vec<String> = s.lines().map(|line| line.to_string()).collect();
let sse_container = crate::providers::response::SseStreamIter::new(lines.into_iter());
let iter = crate::apis::anthropic::AnthropicSseIter::new(sse_container);
Ok(ProviderStreamResponseIter::MessagesStream(iter))
}
(SupportedApi::Anthropic(_), SupportedApi::OpenAI(_)) => {
(SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
let s = std::str::from_utf8(bytes)?;
let lines: Vec<String> = s.lines().map(|line| line.to_string()).collect();
let sse_container = crate::providers::response::SseStreamIter::new(lines.into_iter());
@ -202,7 +202,7 @@ impl Error for ProviderResponseError {
#[cfg(test)]
mod tests {
use super::*;
use crate::clients::endpoints::SupportedApi;
use crate::clients::endpoints::SupportedAPIs;
use crate::providers::id::ProviderId;
use crate::apis::openai::OpenAIApi;
use crate::apis::anthropic::AnthropicApi;
@ -226,7 +226,7 @@ mod tests {
"system_fingerprint": null
});
let bytes = serde_json::to_vec(&resp).unwrap();
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedApi::OpenAI(OpenAIApi::ChatCompletions), &ProviderId::OpenAI));
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), &ProviderId::OpenAI));
assert!(result.is_ok());
match result.unwrap() {
ProviderResponseType::ChatCompletionsResponse(r) => {
@ -251,7 +251,7 @@ mod tests {
"usage": { "input_tokens": 10, "output_tokens": 25, "cache_creation_input_tokens": 5, "cache_read_input_tokens": 3 }
});
let bytes = serde_json::to_vec(&resp).unwrap();
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedApi::Anthropic(AnthropicApi::Messages), &ProviderId::Claude));
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages), &ProviderId::Claude));
assert!(result.is_ok());
match result.unwrap() {
ProviderResponseType::MessagesResponse(r) => {
@ -277,7 +277,7 @@ mod tests {
"usage": { "input_tokens": 10, "output_tokens": 25, "cache_creation_input_tokens": 5, "cache_read_input_tokens": 3 }
});
let bytes = serde_json::to_vec(&resp).unwrap();
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedApi::Anthropic(AnthropicApi::Messages), &ProviderId::OpenAI));
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages), &ProviderId::OpenAI));
assert!(result.is_ok());
match result.unwrap() {
ProviderResponseType::MessagesResponse(r) => {
@ -306,7 +306,7 @@ mod tests {
"system_fingerprint": null
});
let bytes = serde_json::to_vec(&resp).unwrap();
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedApi::OpenAI(OpenAIApi::ChatCompletions), &ProviderId::Claude));
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), &ProviderId::Claude));
assert!(result.is_ok());
match result.unwrap() {
ProviderResponseType::ChatCompletionsResponse(r) => {

View file

@ -10,7 +10,7 @@ use common::ratelimit::Header;
use common::stats::{IncrementingMetric, RecordingMetric};
use common::tracing::{Event, Span, TraceData, Traceparent};
use common::{ratelimit, routing, tokenizer};
use hermesllm::clients::endpoints::SupportedApi;
use hermesllm::clients::endpoints::SupportedAPIs;
use hermesllm::providers::response::ProviderStreamResponseIter;
use hermesllm::{ProviderId, ProviderRequest, ProviderRequestType, ProviderResponseType};
use http::StatusCode;
@ -30,9 +30,9 @@ pub struct StreamContext {
ratelimit_selector: Option<Header>,
streaming_response: bool,
response_tokens: usize,
supported_api: Option<SupportedApi>,
client_api: Option<SupportedAPIs>,
/// The API that should be used for the upstream provider (after compatibility mapping)
resolved_api: Option<SupportedApi>,
resolved_api: Option<SupportedAPIs>,
llm_providers: Rc<LlmProviders>,
llm_provider: Option<Rc<LlmProvider>>,
request_id: Option<String>,
@ -61,7 +61,7 @@ impl StreamContext {
ratelimit_selector: None,
streaming_response: false,
response_tokens: 0,
supported_api: None,
client_api: None,
resolved_api: None,
llm_providers,
llm_provider: None,
@ -214,26 +214,6 @@ impl HttpContext for StreamContext {
return Action::Continue;
}
// Check if this is a supported API endpoint
if SupportedApi::from_endpoint(&request_path).is_none() {
self.send_http_response(404, vec![], Some(b"Unsupported endpoint"));
return Action::Continue;
}
// Get the SupportedApi for routing decisions
let supported_api = SupportedApi::from_endpoint(&request_path);
self.supported_api = supported_api;
// Determine the resolved (upstream) API using provider compatibility
if let (Some(api), Some(provider)) =
(self.supported_api.as_ref(), self.llm_provider.as_ref())
{
let provider_id = provider.to_provider_id();
self.resolved_api = Some(provider_id.compatible_api_for_client(api));
} else {
self.resolved_api = None;
}
let use_agent_orchestrator = match self.overrides.as_ref() {
Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(),
None => false,
@ -299,7 +279,7 @@ impl HttpContext for StreamContext {
}
_ => {
// Use SupportedApi for endpoint routing
if let Some(api) = &self.supported_api {
if let Some(api) = &self.client_api {
let provider_name = &self.llm_provider.as_ref().unwrap().name;
let target_endpoint = api.target_endpoint_for_provider(provider_name);
// Only update path if it's different from the original
@ -313,6 +293,26 @@ impl HttpContext for StreamContext {
self.request_id = self.get_http_request_header(REQUEST_ID_HEADER);
self.traceparent = self.get_http_request_header(TRACE_PARENT_HEADER);
// Check if this is a supported API endpoint
if SupportedAPIs::from_endpoint(&request_path).is_none() {
self.send_http_response(404, vec![], Some(b"Unsupported endpoint"));
return Action::Continue;
}
// Get the SupportedApi for routing decisions
let supported_api: Option<SupportedAPIs> = SupportedAPIs::from_endpoint(&request_path);
self.client_api = supported_api;
// Debug: log provider, client API, resolved API, and request path
if let (Some(api), Some(provider)) = (self.client_api.as_ref(), self.llm_provider.as_ref())
{
let provider_id = provider.to_provider_id();
let resolved_api = provider_id.compatible_api_for_client(api);
self.resolved_api = Some(resolved_api);
} else {
self.resolved_api = None;
}
Action::Continue
}
@ -484,8 +484,9 @@ impl HttpContext for StreamContext {
return Action::Continue;
}
match self.supported_api {
Some(SupportedApi::OpenAI(_)) => {}
match self.client_api {
Some(SupportedAPIs::OpenAIChatCompletions(_)) => {}
Some(SupportedAPIs::AnthropicMessagesAPI(_)) => {}
_ => {
info!("on_http_response_body: non-chatcompletion request");
return Action::Continue;
@ -619,7 +620,7 @@ impl HttpContext for StreamContext {
}
let provider_id = self.get_provider_id();
let supported_api = self.supported_api.as_ref();
let supported_api = self.client_api.as_ref();
if self.streaming_response {
debug!("processing streaming response");