Transformations are working. Next step: add tests.

This commit is contained in:
Salman Paracha 2025-08-22 14:36:46 -07:00
parent 0aa9243093
commit e73a9eb61c
6 changed files with 182 additions and 207 deletions

View file

@ -31,14 +31,6 @@ pub enum SupportedApi {
}
impl SupportedApi {
/// Report whether request/response payloads must be translated for `model`.
///
/// The answer pairs the client-facing API with the model family, as decided by
/// `crate::providers::adapters::is_claude_family`:
/// - Anthropic Messages: conversion is needed unless `model` is in the Claude family.
/// - OpenAI Chat Completions: conversion is needed only when `model` is in the Claude family.
pub fn requires_conversion_for_model(&self, model: &str) -> bool {
    use crate::providers::adapters::is_claude_family;

    let claude_model = is_claude_family(model);
    match self {
        SupportedApi::OpenAI(OpenAIApi::ChatCompletions) => claude_model,
        SupportedApi::Anthropic(AnthropicApi::Messages) => !claude_model,
    }
}
/// Create a SupportedApi from an endpoint path
pub fn from_endpoint(endpoint: &str) -> Option<Self> {
if let Some(openai_api) = OpenAIApi::from_endpoint(endpoint) {
@ -60,14 +52,6 @@ impl SupportedApi {
}
}
/// Name of the API family this value belongs to.
///
/// Returns `"openai"` for any OpenAI API and `"anthropic"` for any Anthropic API.
pub fn api_family(&self) -> &'static str {
    match self {
        SupportedApi::Anthropic(_) => "anthropic",
        SupportedApi::OpenAI(_) => "openai",
    }
}
/// Determine the target endpoint for a given provider
/// For /v1/messages: if provider is Anthropic, use /v1/messages; otherwise use /v1/chat/completions
pub fn target_endpoint_for_provider(&self, provider: &str) -> &'static str {
@ -83,23 +67,6 @@ impl SupportedApi {
_ => self.endpoint()
}
}
/// Report whether payloads must be translated between the Anthropic and OpenAI
/// formats when this client API is served by `provider`.
///
/// A provider is treated as Anthropic-native when its name contains `"anthropic"`
/// or `"claude"`, compared case-insensitively.
///
/// - Anthropic Messages: conversion is needed for every provider that is not
///   Anthropic-native (the request has to become OpenAI format).
/// - OpenAI Chat Completions: conversion is needed only for Anthropic-native providers.
pub fn requires_conversion(&self, provider: &str) -> bool {
    // Lowercase once and share a single predicate between both arms, instead of
    // re-allocating the lowercase string and re-spelling the check in each arm.
    let provider = provider.to_lowercase();
    let is_anthropic_provider = provider.contains("anthropic") || provider.contains("claude");

    match self {
        SupportedApi::Anthropic(AnthropicApi::Messages) => !is_anthropic_provider,
        SupportedApi::OpenAI(OpenAIApi::ChatCompletions) => is_anthropic_provider,
    }
}
}

View file

@ -66,28 +66,30 @@ mod tests {
/// Feeds a small SSE payload (one `chat.completion.chunk` event followed by
/// `data: [DONE]`) through `ProviderStreamResponseIter::try_from`, then checks that
/// the first yielded chunk is `Ok`, carries the content delta `"Hello"`, is not
/// final, and that the iterator yields `None` afterwards.
#[test]
fn test_provider_streaming_response() {
// Test streaming response parsing with sample SSE data
let sse_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}
let sse_data = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}
data: [DONE]
"#;
let result = ProviderStreamResponseIter::try_from((sse_data.as_bytes(), &ProviderId::OpenAI));
assert!(result.is_ok());
use crate::clients::endpoints::SupportedApi;
let api = SupportedApi::OpenAI(crate::apis::OpenAIApi::ChatCompletions);
let result = ProviderStreamResponseIter::try_from((sse_data.as_bytes(), &api, &ProviderId::OpenAI));
assert!(result.is_ok());
let mut streaming_response = result.unwrap();
let mut streaming_response = result.unwrap();
// Test that we can iterate over chunks - it's just an iterator now!
let first_chunk = streaming_response.next();
assert!(first_chunk.is_some());
// Test that we can iterate over chunks - it's just an iterator now!
let first_chunk = streaming_response.next();
assert!(first_chunk.is_some());
let chunk_result = first_chunk.unwrap();
assert!(chunk_result.is_ok());
let chunk_result = first_chunk.unwrap();
assert!(chunk_result.is_ok());
let chunk = chunk_result.unwrap();
assert_eq!(chunk.content_delta(), Some("Hello"));
assert!(!chunk.is_final());
let chunk = chunk_result.unwrap();
assert_eq!(chunk.content_delta(), Some("Hello"));
assert!(!chunk.is_final());
// Test that stream ends properly
// Test that stream ends properly
let final_chunk = streaming_response.next();
assert!(final_chunk.is_none());
}

View file

@ -1,4 +1,6 @@
use std::fmt::Display;
use crate::clients::endpoints::SupportedApi;
use crate::apis::{OpenAIApi, AnthropicApi};
/// Provider identifier enum - simple enum for identifying providers
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
@ -29,6 +31,22 @@ impl From<&str> for ProviderId {
}
}
impl ProviderId {
/// Given a client API, return the compatible upstream API for this provider
pub fn compatible_api_for_client(&self, client_api: &SupportedApi) -> SupportedApi {
match (self, client_api) {
// Claude/Anthropic providers natively support Anthropic APIs
(ProviderId::Claude, SupportedApi::Anthropic(_)) => client_api.clone(),
// Claude/Anthropic providers can also support OpenAI chat completions by mapping to Anthropic Messages
(ProviderId::Claude, SupportedApi::OpenAI(OpenAIApi::ChatCompletions)) => SupportedApi::Anthropic(AnthropicApi::Messages),
// OpenAI-compatible providers only support OpenAI chat completions
(ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch | ProviderId::Gemini | ProviderId::GitHub, SupportedApi::Anthropic(_)) => SupportedApi::OpenAI(OpenAIApi::ChatCompletions),
(ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch | ProviderId::Gemini | ProviderId::GitHub, SupportedApi::OpenAI(_)) => SupportedApi::OpenAI(OpenAIApi::ChatCompletions),
}
}
}
impl Display for ProviderId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {

View file

@ -2,7 +2,6 @@
use crate::apis::openai::ChatCompletionsRequest;
use crate::apis::anthropic::MessagesRequest;
use crate::clients::endpoints::SupportedApi;
use super::{ProviderId, get_provider_config, AdapterType};
use std::error::Error;
use std::fmt;
pub enum ProviderRequestType {
@ -22,53 +21,23 @@ impl TryFrom<&[u8]> for ProviderRequestType {
}
}
impl TryFrom<(&[u8], &ProviderId)> for ProviderRequestType {
/// Parse a request body using the format implied by the given `SupportedApi`
impl TryFrom<(&[u8], &SupportedApi)> for ProviderRequestType {
type Error = std::io::Error;
fn try_from((bytes, provider_id): (&[u8], &ProviderId)) -> Result<Self, Self::Error> {
let config = get_provider_config(provider_id);
match config.adapter_type {
AdapterType::OpenAICompatible => {
fn try_from((bytes, endpoint): (&[u8], &SupportedApi)) -> Result<Self, Self::Error> {
// Use SupportedApi to determine the appropriate request type
match endpoint {
SupportedApi::OpenAI(_) => {
let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request))
}
AdapterType::AnthropicCompatible => {
// For Anthropic providers, try to parse as MessagesRequest first, fallback to ChatCompletionsRequest
if let Ok(messages_request) = MessagesRequest::try_from(bytes) {
Ok(ProviderRequestType::MessagesRequest(messages_request))
} else {
let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request))
}
}
}
}
}
/// Parse a request body, choosing the wire format from the endpoint path and
/// falling back to provider-based parsing when the path is not a supported API.
impl TryFrom<(&[u8], &str, &ProviderId)> for ProviderRequestType {
    type Error = std::io::Error;

    fn try_from((bytes, endpoint, provider_id): (&[u8], &str, &ProviderId)) -> Result<Self, Self::Error> {
        /// Wrap a body-parsing failure as an `InvalidData` I/O error.
        fn invalid_data<E>(err: E) -> std::io::Error
        where
            E: Into<Box<dyn std::error::Error + Send + Sync>>,
        {
            std::io::Error::new(std::io::ErrorKind::InvalidData, err)
        }

        match SupportedApi::from_endpoint(endpoint) {
            // OpenAI endpoints carry a chat-completions request body.
            Some(SupportedApi::OpenAI(_)) => ChatCompletionsRequest::try_from(bytes)
                .map(ProviderRequestType::ChatCompletionsRequest)
                .map_err(invalid_data),
            // Anthropic endpoints carry a messages request body.
            Some(SupportedApi::Anthropic(_)) => MessagesRequest::try_from(bytes)
                .map(ProviderRequestType::MessagesRequest)
                .map_err(invalid_data),
            // Fallback to provider-based parsing for unsupported endpoints.
            None => Self::try_from((bytes, provider_id)),
        }
    }
}

View file

@ -1,11 +1,15 @@
use crate::providers::id::ProviderId;
use serde::Serialize;
use std::error::Error;
use std::fmt;
use crate::apis::openai::ChatCompletionsResponse;
use crate::apis::OpenAISseIter;
use crate::providers::id::ProviderId;
use crate::providers::adapters::{get_provider_config, AdapterType};
use crate::clients::endpoints::SupportedApi;
use std::convert::TryFrom;
#[derive(Serialize)]
pub enum ProviderResponseType {
ChatCompletionsResponse(ChatCompletionsResponse),
//MessagesResponse(MessagesResponse),
@ -16,51 +20,50 @@ pub enum ProviderStreamResponseIter {
//MessagesStream(AnthropicSseIter<std::vec::IntoIter<String>>),
}
impl TryFrom<(&[u8], ProviderId)> for ProviderResponseType {
// --- Response transformation logic for client API compatibility ---
impl TryFrom<(&[u8], &SupportedApi, &ProviderId)> for ProviderResponseType {
type Error = std::io::Error;
fn try_from((bytes, provider_id): (&[u8], ProviderId)) -> Result<Self, Self::Error> {
let config = get_provider_config(&provider_id);
match config.adapter_type {
AdapterType::OpenAICompatible => {
let chat_completions_response: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes)
fn try_from((bytes, client_api, provider_id): (&[u8], &SupportedApi, &ProviderId)) -> Result<Self, Self::Error> {
let upstream_api = provider_id.compatible_api_for_client(client_api);
match (&upstream_api, client_api) {
(SupportedApi::OpenAI(_), SupportedApi::OpenAI(_)) => {
let resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(ProviderResponseType::ChatCompletionsResponse(chat_completions_response))
Ok(ProviderResponseType::ChatCompletionsResponse(resp))
}
AdapterType::AnthropicCompatible => {
// TODO: Implement MessagesResponse parsing for Anthropic-compatible providers
todo!("AnthropicCompatible response parsing not yet implemented");
(SupportedApi::OpenAI(_), SupportedApi::Anthropic(_)) => {
// If you add a MessagesResponse variant, return it here. For now, just error or serialize as needed.
Err(std::io::Error::new(std::io::ErrorKind::Other, "Anthropic response variant not implemented"))
}
_ => Err(std::io::Error::new(std::io::ErrorKind::Other, "Unsupported response transformation")),
}
}
}
impl TryFrom<(&[u8], &ProviderId)> for ProviderStreamResponseIter {
impl TryFrom<(&[u8], &SupportedApi, &ProviderId)> for ProviderStreamResponseIter {
type Error = Box<dyn std::error::Error + Send + Sync>;
fn try_from((bytes, provider_id): (&[u8], &ProviderId)) -> Result<Self, Self::Error> {
let config = get_provider_config(provider_id);
// Parse SSE (Server-Sent Events) streaming data - protocol layer
let s = std::str::from_utf8(bytes)?;
let lines: Vec<String> = s.lines().map(|line| line.to_string()).collect();
match config.adapter_type {
AdapterType::OpenAICompatible => {
// Delegate to OpenAI-specific iterator implementation
let sse_container = SseStreamIter::new(lines.into_iter());
fn try_from((bytes, client_api, provider_id): (&[u8], &SupportedApi, &ProviderId)) -> Result<Self, Self::Error> {
let upstream_api = provider_id.compatible_api_for_client(client_api);
match (&upstream_api, client_api) {
(SupportedApi::OpenAI(_), SupportedApi::OpenAI(_)) => {
let s = std::str::from_utf8(bytes)?;
let lines: Vec<String> = s.lines().map(|line| line.to_string()).collect();
let sse_container = crate::providers::response::SseStreamIter::new(lines.into_iter());
let iter = crate::apis::openai::OpenAISseIter::new(sse_container);
Ok(ProviderStreamResponseIter::ChatCompletionsStream(iter))
}
AdapterType::AnthropicCompatible => {
// TODO: Implement Anthropic streaming support
todo!("AnthropicCompatible streaming not yet implemented");
(SupportedApi::OpenAI(_), SupportedApi::Anthropic(_)) => {
// TODO: Implement streaming transformation from OpenAI to Anthropic
Err("Anthropic streaming response variant not implemented".into())
}
_ => Err("Unsupported streaming response transformation".into()),
}
}
}
impl Iterator for ProviderStreamResponseIter {
type Item = Result<Box<dyn ProviderStreamResponse>, Box<dyn std::error::Error + Send + Sync>>;

View file

@ -12,9 +12,7 @@ use common::tracing::{Event, Span, TraceData, Traceparent};
use common::{ratelimit, routing, tokenizer};
use hermesllm::clients::endpoints::SupportedApi;
use hermesllm::providers::response::ProviderStreamResponseIter;
use hermesllm::{
ProviderId, ProviderRequest, ProviderRequestType, ProviderResponse, ProviderResponseType,
};
use hermesllm::{ProviderId, ProviderRequest, ProviderRequestType, ProviderResponseType};
use http::StatusCode;
use log::{debug, info, warn};
use proxy_wasm::hostcalls::get_current_time;
@ -33,6 +31,8 @@ pub struct StreamContext {
streaming_response: bool,
response_tokens: usize,
supported_api: Option<SupportedApi>,
/// The API that should be used for the upstream provider (after compatibility mapping)
resolved_api: Option<SupportedApi>,
llm_providers: Rc<LlmProviders>,
llm_provider: Option<Rc<LlmProvider>>,
request_id: Option<String>,
@ -62,6 +62,7 @@ impl StreamContext {
streaming_response: false,
response_tokens: 0,
supported_api: None,
resolved_api: None,
llm_providers,
llm_provider: None,
request_id: None,
@ -223,6 +224,16 @@ impl HttpContext for StreamContext {
let supported_api = SupportedApi::from_endpoint(&request_path);
self.supported_api = supported_api;
// Determine the resolved (upstream) API using provider compatibility
if let (Some(api), Some(provider)) =
(self.supported_api.as_ref(), self.llm_provider.as_ref())
{
let provider_id = provider.to_provider_id();
self.resolved_api = Some(provider_id.compatible_api_for_client(api));
} else {
self.resolved_api = None;
}
let use_agent_orchestrator = match self.overrides.as_ref() {
Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(),
None => false,
@ -340,22 +351,26 @@ impl HttpContext for StreamContext {
}
};
let provider_id = self.get_provider_id();
let request_path = self.get_http_request_header(":path").unwrap_or_default();
let mut deserialized_body = match ProviderRequestType::try_from((
&body_bytes[..],
request_path.as_str(),
&provider_id,
)) {
Ok(deserialized) => deserialized,
Err(e) => {
debug!(
"on_http_request_body: request body: {}",
String::from_utf8_lossy(&body_bytes)
);
let mut deserialized_body = match self.resolved_api.as_ref() {
Some(resolved_api) => {
match ProviderRequestType::try_from((&body_bytes[..], resolved_api)) {
Ok(deserialized) => deserialized,
Err(e) => {
debug!(
"on_http_request_body: request body: {}",
String::from_utf8_lossy(&body_bytes)
);
self.send_server_error(
ServerError::LogicError(format!("Request parsing error: {}", e)),
Some(StatusCode::BAD_REQUEST),
);
return Action::Pause;
}
}
}
None => {
self.send_server_error(
ServerError::LogicError(format!("Request parsing error: {}", e)),
ServerError::LogicError("No resolved API for provider".to_string()),
Some(StatusCode::BAD_REQUEST),
);
return Action::Pause;
@ -603,99 +618,100 @@ impl HttpContext for StreamContext {
);
}
let provider_id = self.get_provider_id();
let supported_api = self.supported_api.as_ref();
if self.streaming_response {
debug!("processing streaming response");
match ProviderStreamResponseIter::try_from((&body[..], &self.get_provider_id())) {
Ok(mut streaming_response) => {
// Process each streaming chunk
while let Some(chunk_result) = streaming_response.next() {
match chunk_result {
Ok(chunk) => {
// Compute TTFT on first chunk
if self.ttft_duration.is_none() {
let current_time = get_current_time().unwrap();
self.ttft_time = Some(current_time_ns());
match current_time.duration_since(self.start_time) {
Ok(duration) => {
let duration_ms = duration.as_millis();
info!(
"on_http_response_body: time to first token: {}ms",
duration_ms
);
self.ttft_duration = Some(duration);
self.metrics
.time_to_first_token
.record(duration_ms as u64);
match (supported_api, self.resolved_api.as_ref()) {
(Some(supported_api), Some(_)) => {
match ProviderStreamResponseIter::try_from((
&body[..],
supported_api,
&provider_id,
)) {
Ok(mut streaming_response) => {
while let Some(chunk_result) = streaming_response.next() {
match chunk_result {
Ok(chunk) => {
if self.ttft_duration.is_none() {
let current_time = get_current_time().unwrap();
self.ttft_time = Some(current_time_ns());
match current_time.duration_since(self.start_time) {
Ok(duration) => {
let duration_ms = duration.as_millis();
info!("on_http_response_body: time to first token: {}ms", duration_ms);
self.ttft_duration = Some(duration);
self.metrics
.time_to_first_token
.record(duration_ms as u64);
}
Err(e) => {
warn!("SystemTime error: {:?}", e);
}
}
}
Err(e) => {
warn!("SystemTime error: {:?}", e);
if chunk.is_final() {
debug!("Received final streaming chunk");
}
if let Some(content) = chunk.content_delta() {
let estimated_tokens = content.len() / 4;
self.response_tokens += estimated_tokens.max(1);
}
}
}
// For streaming responses, we handle token counting differently
// The ProviderStreamResponse trait provides content_delta, is_final, and role
// Token counting for streaming responses typically happens with final usage chunk
if chunk.is_final() {
// For now, we'll implement basic token estimation
// In a complete implementation, the final chunk would contain usage information
debug!("Received final streaming chunk");
}
// For now, estimate tokens from content delta
if let Some(content) = chunk.content_delta() {
// Rough estimation: ~4 characters per token
let estimated_tokens = content.len() / 4;
self.response_tokens += estimated_tokens.max(1);
Err(e) => {
warn!("Error processing streaming chunk: {}", e);
return Action::Continue;
}
}
}
Err(e) => {
warn!("Error processing streaming chunk: {}", e);
return Action::Continue;
}
}
Err(e) => {
warn!("Failed to parse streaming response: {}", e);
}
}
}
Err(e) => {
warn!("Failed to parse streaming response: {}", e);
_ => {
warn!("Missing supported_api or resolved_api for streaming response");
}
}
} else {
debug!("non streaming response");
let provider_id = self.get_provider_id();
let response: ProviderResponseType =
match ProviderResponseType::try_from((&body[..], provider_id)) {
Ok(response) => response,
Err(e) => {
warn!(
"could not parse response: {}, body str: {}",
e,
String::from_utf8_lossy(&body)
);
debug!(
"on_http_response_body: S[{}], response body: {}",
self.context_id,
String::from_utf8_lossy(&body)
);
self.send_server_error(
ServerError::LogicError(format!("Response parsing error: {}", e)),
Some(StatusCode::BAD_REQUEST),
);
return Action::Continue;
match (supported_api, self.resolved_api.as_ref()) {
(Some(supported_api), Some(_)) => {
match ProviderResponseType::try_from((&body[..], supported_api, &provider_id)) {
Ok(response) => match serde_json::to_vec(&response) {
Ok(bytes) => {
self.set_http_response_body(0, bytes.len(), &bytes);
}
Err(e) => {
self.send_server_error(
ServerError::LogicError(format!(
"Response serialization error: {}",
e
)),
Some(StatusCode::BAD_REQUEST),
);
return Action::Continue;
}
},
Err(e) => {
warn!(
"could not parse response: {}, body str: {}",
e,
String::from_utf8_lossy(&body)
);
self.send_server_error(
ServerError::LogicError(format!("Response parsing error: {}", e)),
Some(StatusCode::BAD_REQUEST),
);
return Action::Continue;
}
}
};
// Use provider interface to extract usage information
if let Some((prompt_tokens, completion_tokens, total_tokens)) =
response.extract_usage_counts()
{
debug!(
"Response usage: prompt={}, completion={}, total={}",
prompt_tokens, completion_tokens, total_tokens
);
self.response_tokens = completion_tokens;
} else {
warn!("No usage information found in response");
}
_ => {
warn!("Missing supported_api or resolved_api for non-streaming response");
}
}
}