mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
fixed issues with non-streaming responses
This commit is contained in:
parent
77491b4a69
commit
9f6d2464f6
4 changed files with 89 additions and 118 deletions
|
|
@ -21,7 +21,7 @@
|
|||
//! assert!(endpoints.contains(&"/v1/messages"));
|
||||
//! ```
|
||||
|
||||
use crate::apis::{AnthropicApi, OpenAIApi, ApiDefinition};
|
||||
use crate::{apis::{AnthropicApi, ApiDefinition, OpenAIApi}, ProviderId};
|
||||
|
||||
/// Unified enum representing all supported API endpoints across providers
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
|
|
@ -52,18 +52,34 @@ impl SupportedAPIs {
|
|||
}
|
||||
}
|
||||
|
||||
//TODO: we need to clean this up. Why do we need this in the first place?
|
||||
pub fn target_endpoint_for_provider(&self, provider: &str) -> &'static str {
|
||||
pub fn target_endpoint_for_provider(&self, provider_id: &ProviderId, request_path: &str) -> String {
|
||||
let default_endpoint = "/v1/chat/completions".to_string();
|
||||
match self {
|
||||
SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) => {
|
||||
if provider.to_lowercase().contains("anthropic") ||
|
||||
provider.to_lowercase().contains("claude") {
|
||||
"/v1/messages"
|
||||
} else {
|
||||
"/v1/chat/completions"
|
||||
match provider_id {
|
||||
ProviderId::Claude => "/v1/messages".to_string(),
|
||||
_ => default_endpoint,
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
match provider_id {
|
||||
ProviderId::Groq => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
format!("/openai{}", request_path)
|
||||
} else {
|
||||
default_endpoint
|
||||
}
|
||||
}
|
||||
ProviderId::Gemini => {
|
||||
if request_path.starts_with("/v1/") {
|
||||
"/v1beta/openai/chat/completions".to_string()
|
||||
} else {
|
||||
default_endpoint
|
||||
}
|
||||
}
|
||||
_ => default_endpoint,
|
||||
}
|
||||
}
|
||||
_ => self.endpoint()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ use std::convert::TryFrom;
|
|||
use crate::apis::anthropic::MessagesResponse;
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ProviderResponseType {
|
||||
ChatCompletionsResponse(ChatCompletionsResponse),
|
||||
MessagesResponse(MessagesResponse),
|
||||
|
|
@ -104,19 +105,6 @@ impl Iterator for ProviderStreamResponseIter {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper to serialize only the inner struct, not the enum wrapper.
|
||||
// This avoids the problem where serde serializes the enum variant as a wrapper object in JSON.
|
||||
impl ProviderResponseType {
|
||||
/// Serialize the response as JSON bytes, omitting the enum wrapper.
|
||||
pub fn as_json_bytes(&self) -> Result<Vec<u8>, serde_json::Error> {
|
||||
match self {
|
||||
ProviderResponseType::ChatCompletionsResponse(resp) => serde_json::to_vec(resp),
|
||||
ProviderResponseType::MessagesResponse(resp) => serde_json::to_vec(resp),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ProviderResponse: Send + Sync {
|
||||
/// Get usage information if available - returns dynamic trait object
|
||||
fn usage(&self) -> Option<&dyn TokenUsage>;
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ use common::stats::{IncrementingMetric, RecordingMetric};
|
|||
use common::tracing::{Event, Span, TraceData, Traceparent};
|
||||
use common::{ratelimit, routing, tokenizer};
|
||||
use hermesllm::clients::endpoints::SupportedAPIs;
|
||||
use hermesllm::providers::response::ProviderStreamResponseIter;
|
||||
use hermesllm::providers::response::{ProviderResponse, ProviderStreamResponseIter};
|
||||
use hermesllm::{ProviderId, ProviderRequest, ProviderRequestType, ProviderResponseType};
|
||||
use http::StatusCode;
|
||||
use log::{debug, info, warn};
|
||||
|
|
@ -85,6 +85,18 @@ impl StreamContext {
|
|||
self.llm_provider().to_provider_id()
|
||||
}
|
||||
|
||||
//This function assumes that the provider has been set.
|
||||
fn update_upstream_path(&mut self, request_path: &str) {
|
||||
let hermes_provider_id = self.llm_provider().to_provider_id();
|
||||
if let Some(api) = &self.client_api {
|
||||
let target_endpoint =
|
||||
api.target_endpoint_for_provider(&hermes_provider_id, request_path);
|
||||
if target_endpoint != request_path {
|
||||
self.set_http_request_header(":path", Some(&target_endpoint));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn select_llm_provider(&mut self) {
|
||||
let provider_hint = self
|
||||
.get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
|
||||
|
|
@ -95,28 +107,6 @@ impl StreamContext {
|
|||
provider_hint,
|
||||
));
|
||||
|
||||
match self.llm_provider.as_ref().unwrap().provider_interface {
|
||||
LlmProviderType::Groq => {
|
||||
if let Some(path) = self.get_http_request_header(":path") {
|
||||
if path.starts_with("/v1/") {
|
||||
let new_path = format!("/openai{}", path);
|
||||
self.set_http_request_header(":path", Some(new_path.as_str()));
|
||||
}
|
||||
}
|
||||
}
|
||||
LlmProviderType::Gemini => {
|
||||
if let Some(path) = self.get_http_request_header(":path") {
|
||||
if path == "/v1/chat/completions" {
|
||||
self.set_http_request_header(
|
||||
":path",
|
||||
Some("/v1beta/openai/chat/completions"),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
debug!(
|
||||
"request received: llm provider hint: {}, selected provider: {}",
|
||||
self.get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
|
||||
|
|
@ -227,10 +217,35 @@ impl HttpContext for StreamContext {
|
|||
self.llm_provider = Some(Rc::new(LlmProvider {
|
||||
name: routing_header_value.to_string(),
|
||||
provider_interface: LlmProviderType::OpenAI,
|
||||
..Default::default()
|
||||
..Default::default() //TODO: THiS IS BROKEN. WHY ARE WE ASSUMING OPENAI FOR UPSTREAM?
|
||||
}));
|
||||
} else {
|
||||
//TODO: Fix this brittle code path. We need to return values and have compile time
|
||||
self.select_llm_provider();
|
||||
|
||||
// Check if this is a supported API endpoint
|
||||
if SupportedAPIs::from_endpoint(&request_path).is_none() {
|
||||
self.send_http_response(404, vec![], Some(b"Unsupported endpoint"));
|
||||
return Action::Continue;
|
||||
}
|
||||
|
||||
// Get the SupportedApi for routing decisions
|
||||
let supported_api: Option<SupportedAPIs> = SupportedAPIs::from_endpoint(&request_path);
|
||||
self.client_api = supported_api;
|
||||
|
||||
// Debug: log provider, client API, resolved API, and request path
|
||||
if let (Some(api), Some(provider)) =
|
||||
(self.client_api.as_ref(), self.llm_provider.as_ref())
|
||||
{
|
||||
let provider_id = provider.to_provider_id();
|
||||
self.resolved_api = Some(provider_id.compatible_api_for_client(api));
|
||||
} else {
|
||||
self.resolved_api = None;
|
||||
}
|
||||
|
||||
//We need to update the upstream path if there is a variation for a provider like Gemini/Groq, etc.
|
||||
self.update_upstream_path(&request_path);
|
||||
|
||||
if self.llm_provider().endpoint.is_some() {
|
||||
self.add_http_request_header(
|
||||
ARCH_ROUTING_HEADER,
|
||||
|
|
@ -257,62 +272,9 @@ impl HttpContext for StreamContext {
|
|||
self.delete_content_length_header();
|
||||
self.save_ratelimit_header();
|
||||
|
||||
// Apply provider-specific path routing
|
||||
match self.llm_provider.as_ref().unwrap().provider_interface {
|
||||
LlmProviderType::Groq => {
|
||||
if let Some(path) = self.get_http_request_header(":path") {
|
||||
if path.starts_with("/v1/") {
|
||||
let new_path = format!("/openai{}", path);
|
||||
self.set_http_request_header(":path", Some(new_path.as_str()));
|
||||
}
|
||||
}
|
||||
}
|
||||
LlmProviderType::Gemini => {
|
||||
if let Some(path) = self.get_http_request_header(":path") {
|
||||
if path == "/v1/chat/completions" {
|
||||
self.set_http_request_header(
|
||||
":path",
|
||||
Some("/v1beta/openai/chat/completions"),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// Use SupportedApi for endpoint routing
|
||||
if let Some(api) = &self.client_api {
|
||||
let provider_name = &self.llm_provider.as_ref().unwrap().name;
|
||||
let target_endpoint = api.target_endpoint_for_provider(provider_name);
|
||||
// Only update path if it's different from the original
|
||||
if target_endpoint != request_path {
|
||||
self.set_http_request_header(":path", Some(target_endpoint));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.request_id = self.get_http_request_header(REQUEST_ID_HEADER);
|
||||
self.traceparent = self.get_http_request_header(TRACE_PARENT_HEADER);
|
||||
|
||||
// Check if this is a supported API endpoint
|
||||
if SupportedAPIs::from_endpoint(&request_path).is_none() {
|
||||
self.send_http_response(404, vec![], Some(b"Unsupported endpoint"));
|
||||
return Action::Continue;
|
||||
}
|
||||
|
||||
// Get the SupportedApi for routing decisions
|
||||
let supported_api: Option<SupportedAPIs> = SupportedAPIs::from_endpoint(&request_path);
|
||||
self.client_api = supported_api;
|
||||
|
||||
// Debug: log provider, client API, resolved API, and request path
|
||||
if let (Some(api), Some(provider)) = (self.client_api.as_ref(), self.llm_provider.as_ref())
|
||||
{
|
||||
let provider_id = provider.to_provider_id();
|
||||
let resolved_api = provider_id.compatible_api_for_client(api);
|
||||
self.resolved_api = Some(resolved_api);
|
||||
} else {
|
||||
self.resolved_api = None;
|
||||
}
|
||||
|
||||
Action::Continue
|
||||
}
|
||||
|
||||
|
|
@ -678,30 +640,24 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
} else {
|
||||
debug!("non streaming response");
|
||||
match (supported_api, self.resolved_api.as_ref()) {
|
||||
let provider_id = self.get_provider_id();
|
||||
let supported_api = self.client_api.as_ref();
|
||||
|
||||
let response: ProviderResponseType = match (supported_api, self.resolved_api.as_ref()) {
|
||||
(Some(supported_api), Some(_)) => {
|
||||
match ProviderResponseType::try_from((&body[..], supported_api, &provider_id)) {
|
||||
Ok(response) => match response.as_json_bytes() {
|
||||
Ok(bytes) => {
|
||||
self.set_http_response_body(0, bytes.len(), &bytes);
|
||||
}
|
||||
Err(e) => {
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(format!(
|
||||
"Response serialization error: {}",
|
||||
e
|
||||
)),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Continue;
|
||||
}
|
||||
},
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"could not parse response: {}, body str: {}",
|
||||
e,
|
||||
String::from_utf8_lossy(&body)
|
||||
);
|
||||
debug!(
|
||||
"on_http_response_body: S[{}], response body: {}",
|
||||
self.context_id,
|
||||
String::from_utf8_lossy(&body)
|
||||
);
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(format!("Response parsing error: {}", e)),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
|
|
@ -712,7 +668,21 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
_ => {
|
||||
warn!("Missing supported_api or resolved_api for non-streaming response");
|
||||
return Action::Continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Use provider interface to extract usage information
|
||||
if let Some((prompt_tokens, completion_tokens, total_tokens)) =
|
||||
response.extract_usage_counts()
|
||||
{
|
||||
debug!(
|
||||
"Response usage: prompt={}, completion={}, total={}",
|
||||
prompt_tokens, completion_tokens, total_tokens
|
||||
);
|
||||
self.response_tokens = completion_tokens;
|
||||
} else {
|
||||
warn!("No usage information found in response");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -665,9 +665,6 @@ impl StreamContext {
|
|||
}
|
||||
|
||||
pub fn default_target_handler(&self, body: Vec<u8>, mut callout_context: StreamCallContext) {
|
||||
// Debug: print raw bytes in hex to diagnose extra data
|
||||
debug!("raw upstream response bytes (hex): {}",
|
||||
body.iter().map(|b| format!("{:02x}", b)).collect::<Vec<_>>().join(" "));
|
||||
let prompt_target = self
|
||||
.prompt_targets
|
||||
.get(callout_context.prompt_target_name.as_ref().unwrap())
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue