fixed issues with non-streaming responses

This commit is contained in:
Salman Paracha 2025-08-24 18:52:48 -07:00
parent 77491b4a69
commit 9f6d2464f6
4 changed files with 89 additions and 118 deletions

View file

@ -21,7 +21,7 @@
//! assert!(endpoints.contains(&"/v1/messages"));
//! ```
use crate::apis::{AnthropicApi, OpenAIApi, ApiDefinition};
use crate::{apis::{AnthropicApi, ApiDefinition, OpenAIApi}, ProviderId};
/// Unified enum representing all supported API endpoints across providers
#[derive(Debug, Clone, PartialEq)]
@ -52,18 +52,34 @@ impl SupportedAPIs {
}
}
//TODO: we need to clean this up. Why do we need this in the first place?
pub fn target_endpoint_for_provider(&self, provider: &str) -> &'static str {
pub fn target_endpoint_for_provider(&self, provider_id: &ProviderId, request_path: &str) -> String {
let default_endpoint = "/v1/chat/completions".to_string();
match self {
SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) => {
if provider.to_lowercase().contains("anthropic") ||
provider.to_lowercase().contains("claude") {
"/v1/messages"
} else {
"/v1/chat/completions"
match provider_id {
ProviderId::Claude => "/v1/messages".to_string(),
_ => default_endpoint,
}
}
_ => {
match provider_id {
ProviderId::Groq => {
if request_path.starts_with("/v1/") {
format!("/openai{}", request_path)
} else {
default_endpoint
}
}
ProviderId::Gemini => {
if request_path.starts_with("/v1/") {
"/v1beta/openai/chat/completions".to_string()
} else {
default_endpoint
}
}
_ => default_endpoint,
}
}
_ => self.endpoint()
}
}
}

View file

@ -12,6 +12,7 @@ use std::convert::TryFrom;
use crate::apis::anthropic::MessagesResponse;
#[derive(Serialize)]
#[serde(untagged)]
pub enum ProviderResponseType {
ChatCompletionsResponse(ChatCompletionsResponse),
MessagesResponse(MessagesResponse),
@ -104,19 +105,6 @@ impl Iterator for ProviderStreamResponseIter {
}
}
}
// Helper to serialize only the inner struct, not the enum wrapper.
// This avoids the problem where serde serializes the enum variant as a wrapper object in JSON.
impl ProviderResponseType {
/// Serialize the response as JSON bytes, omitting the enum wrapper.
pub fn as_json_bytes(&self) -> Result<Vec<u8>, serde_json::Error> {
match self {
ProviderResponseType::ChatCompletionsResponse(resp) => serde_json::to_vec(resp),
ProviderResponseType::MessagesResponse(resp) => serde_json::to_vec(resp),
}
}
}
pub trait ProviderResponse: Send + Sync {
/// Get usage information if available - returns dynamic trait object
fn usage(&self) -> Option<&dyn TokenUsage>;

View file

@ -11,7 +11,7 @@ use common::stats::{IncrementingMetric, RecordingMetric};
use common::tracing::{Event, Span, TraceData, Traceparent};
use common::{ratelimit, routing, tokenizer};
use hermesllm::clients::endpoints::SupportedAPIs;
use hermesllm::providers::response::ProviderStreamResponseIter;
use hermesllm::providers::response::{ProviderResponse, ProviderStreamResponseIter};
use hermesllm::{ProviderId, ProviderRequest, ProviderRequestType, ProviderResponseType};
use http::StatusCode;
use log::{debug, info, warn};
@ -85,6 +85,18 @@ impl StreamContext {
self.llm_provider().to_provider_id()
}
//This function assumes that the provider has been set.
fn update_upstream_path(&mut self, request_path: &str) {
let hermes_provider_id = self.llm_provider().to_provider_id();
if let Some(api) = &self.client_api {
let target_endpoint =
api.target_endpoint_for_provider(&hermes_provider_id, request_path);
if target_endpoint != request_path {
self.set_http_request_header(":path", Some(&target_endpoint));
}
}
}
fn select_llm_provider(&mut self) {
let provider_hint = self
.get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
@ -95,28 +107,6 @@ impl StreamContext {
provider_hint,
));
match self.llm_provider.as_ref().unwrap().provider_interface {
LlmProviderType::Groq => {
if let Some(path) = self.get_http_request_header(":path") {
if path.starts_with("/v1/") {
let new_path = format!("/openai{}", path);
self.set_http_request_header(":path", Some(new_path.as_str()));
}
}
}
LlmProviderType::Gemini => {
if let Some(path) = self.get_http_request_header(":path") {
if path == "/v1/chat/completions" {
self.set_http_request_header(
":path",
Some("/v1beta/openai/chat/completions"),
);
}
}
}
_ => {}
}
debug!(
"request received: llm provider hint: {}, selected provider: {}",
self.get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
@ -227,10 +217,35 @@ impl HttpContext for StreamContext {
self.llm_provider = Some(Rc::new(LlmProvider {
name: routing_header_value.to_string(),
provider_interface: LlmProviderType::OpenAI,
..Default::default()
..Default::default() //TODO: THiS IS BROKEN. WHY ARE WE ASSUMING OPENAI FOR UPSTREAM?
}));
} else {
//TODO: Fix this brittle code path. We need to return values and have compile time
self.select_llm_provider();
// Check if this is a supported API endpoint
if SupportedAPIs::from_endpoint(&request_path).is_none() {
self.send_http_response(404, vec![], Some(b"Unsupported endpoint"));
return Action::Continue;
}
// Get the SupportedApi for routing decisions
let supported_api: Option<SupportedAPIs> = SupportedAPIs::from_endpoint(&request_path);
self.client_api = supported_api;
// Debug: log provider, client API, resolved API, and request path
if let (Some(api), Some(provider)) =
(self.client_api.as_ref(), self.llm_provider.as_ref())
{
let provider_id = provider.to_provider_id();
self.resolved_api = Some(provider_id.compatible_api_for_client(api));
} else {
self.resolved_api = None;
}
//We need to update the upstream path if there is a variation for a provider like Gemini/Groq, etc.
self.update_upstream_path(&request_path);
if self.llm_provider().endpoint.is_some() {
self.add_http_request_header(
ARCH_ROUTING_HEADER,
@ -257,62 +272,9 @@ impl HttpContext for StreamContext {
self.delete_content_length_header();
self.save_ratelimit_header();
// Apply provider-specific path routing
match self.llm_provider.as_ref().unwrap().provider_interface {
LlmProviderType::Groq => {
if let Some(path) = self.get_http_request_header(":path") {
if path.starts_with("/v1/") {
let new_path = format!("/openai{}", path);
self.set_http_request_header(":path", Some(new_path.as_str()));
}
}
}
LlmProviderType::Gemini => {
if let Some(path) = self.get_http_request_header(":path") {
if path == "/v1/chat/completions" {
self.set_http_request_header(
":path",
Some("/v1beta/openai/chat/completions"),
);
}
}
}
_ => {
// Use SupportedApi for endpoint routing
if let Some(api) = &self.client_api {
let provider_name = &self.llm_provider.as_ref().unwrap().name;
let target_endpoint = api.target_endpoint_for_provider(provider_name);
// Only update path if it's different from the original
if target_endpoint != request_path {
self.set_http_request_header(":path", Some(target_endpoint));
}
}
}
}
self.request_id = self.get_http_request_header(REQUEST_ID_HEADER);
self.traceparent = self.get_http_request_header(TRACE_PARENT_HEADER);
// Check if this is a supported API endpoint
if SupportedAPIs::from_endpoint(&request_path).is_none() {
self.send_http_response(404, vec![], Some(b"Unsupported endpoint"));
return Action::Continue;
}
// Get the SupportedApi for routing decisions
let supported_api: Option<SupportedAPIs> = SupportedAPIs::from_endpoint(&request_path);
self.client_api = supported_api;
// Debug: log provider, client API, resolved API, and request path
if let (Some(api), Some(provider)) = (self.client_api.as_ref(), self.llm_provider.as_ref())
{
let provider_id = provider.to_provider_id();
let resolved_api = provider_id.compatible_api_for_client(api);
self.resolved_api = Some(resolved_api);
} else {
self.resolved_api = None;
}
Action::Continue
}
@ -678,30 +640,24 @@ impl HttpContext for StreamContext {
}
} else {
debug!("non streaming response");
match (supported_api, self.resolved_api.as_ref()) {
let provider_id = self.get_provider_id();
let supported_api = self.client_api.as_ref();
let response: ProviderResponseType = match (supported_api, self.resolved_api.as_ref()) {
(Some(supported_api), Some(_)) => {
match ProviderResponseType::try_from((&body[..], supported_api, &provider_id)) {
Ok(response) => match response.as_json_bytes() {
Ok(bytes) => {
self.set_http_response_body(0, bytes.len(), &bytes);
}
Err(e) => {
self.send_server_error(
ServerError::LogicError(format!(
"Response serialization error: {}",
e
)),
Some(StatusCode::BAD_REQUEST),
);
return Action::Continue;
}
},
Ok(response) => response,
Err(e) => {
warn!(
"could not parse response: {}, body str: {}",
e,
String::from_utf8_lossy(&body)
);
debug!(
"on_http_response_body: S[{}], response body: {}",
self.context_id,
String::from_utf8_lossy(&body)
);
self.send_server_error(
ServerError::LogicError(format!("Response parsing error: {}", e)),
Some(StatusCode::BAD_REQUEST),
@ -712,7 +668,21 @@ impl HttpContext for StreamContext {
}
_ => {
warn!("Missing supported_api or resolved_api for non-streaming response");
return Action::Continue;
}
};
// Use provider interface to extract usage information
if let Some((prompt_tokens, completion_tokens, total_tokens)) =
response.extract_usage_counts()
{
debug!(
"Response usage: prompt={}, completion={}, total={}",
prompt_tokens, completion_tokens, total_tokens
);
self.response_tokens = completion_tokens;
} else {
warn!("No usage information found in response");
}
}

View file

@ -665,9 +665,6 @@ impl StreamContext {
}
pub fn default_target_handler(&self, body: Vec<u8>, mut callout_context: StreamCallContext) {
// Debug: print raw bytes in hex to diagnose extra data
debug!("raw upstream response bytes (hex): {}",
body.iter().map(|b| format!("{:02x}", b)).collect::<Vec<_>>().join(" "));
let prompt_target = self
.prompt_targets
.get(callout_context.prompt_target_name.as_ref().unwrap())