updating the implementation of /v1/chat/completions to use the generic provider interfaces

This commit is contained in:
Salman Paracha 2025-08-08 23:17:29 -07:00
parent 93ff4d7b1f
commit 203fc8f9a9
8 changed files with 441 additions and 89 deletions

View file

@ -177,6 +177,18 @@ impl Display for LlmProviderType {
}
}
impl LlmProviderType {
    /// Builds the `hermesllm` provider instance that corresponds to this provider type.
    ///
    /// The `Display` rendering of `self` is passed to `ProviderInstance::from_name`,
    /// which selects the concrete provider implementation.
    pub fn create_provider_instance(&self) -> hermesllm::ProviderInstance {
        // Every provider currently speaks an OpenAI-compatible API.
        // TODO: Return specific provider instances when implementing different APIs
        let provider_name = self.to_string();
        hermesllm::ProviderInstance::from_name(provider_name.as_str())
    }
}
#[derive(Serialize, Deserialize, Debug)]
pub struct ModelUsagePreference {
pub model: String,
@ -252,6 +264,14 @@ impl Display for LlmProvider {
}
}
impl LlmProvider {
    /// Returns the provider instance for this configured provider.
    ///
    /// Convenience wrapper: the work is done by `create_provider_instance`
    /// on the `provider_interface` field.
    pub fn create_provider_instance(&self) -> hermesllm::ProviderInstance {
        let interface = &self.provider_interface;
        interface.create_provider_instance()
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Endpoint {
pub endpoint: Option<String>,

View file

@ -5,6 +5,11 @@ pub mod providers;
pub mod apis;
pub mod clients;
// Re-export important traits
pub use providers::traits::*;
pub use providers::openai::provider::OpenAIProvider;
pub use providers::provider_enum::ProviderInstance;
use std::fmt::Display;
pub enum Provider {
@ -34,6 +39,45 @@ impl From<&str> for Provider {
}
}
impl Provider {
/// Get the API endpoint path for this provider
pub fn api_path(&self) -> &'static str {
match self {
Provider::OpenAI => "/v1/chat/completions",
Provider::Groq => "/openai/v1/chat/completions", // Groq maps to OpenAI-compatible endpoint
Provider::Gemini => "/v1/models", // TODO: Update with correct Gemini path
Provider::Claude => "/v1/messages", // TODO: Update with correct Claude path
Provider::Mistral => "/v1/chat/completions", // Mistral uses OpenAI-compatible API
Provider::Deepseek => "/v1/chat/completions", // DeepSeek uses OpenAI-compatible API
Provider::Arch => "/v1/chat/completions", // Arch gateway endpoint
Provider::Github => "/models", // TODO: Update with correct GitHub models path
}
}
/// Check if this provider uses OpenAI-compatible API format
pub fn uses_openai_format(&self) -> bool {
match self {
Provider::OpenAI | Provider::Groq | Provider::Mistral | Provider::Deepseek | Provider::Arch => true,
Provider::Gemini | Provider::Claude | Provider::Github => false, // These have their own formats
}
}
/// Create a provider implementation instance for this provider
pub fn create_provider_instance(&self) -> ProviderInstance {
match self {
Provider::OpenAI => ProviderInstance::OpenAI(OpenAIProvider),
Provider::Groq => ProviderInstance::OpenAI(OpenAIProvider), // Groq uses OpenAI-compatible API
Provider::Mistral => ProviderInstance::OpenAI(OpenAIProvider), // Mistral uses OpenAI-compatible API
Provider::Deepseek => ProviderInstance::OpenAI(OpenAIProvider), // Deepseek uses OpenAI-compatible API
Provider::Arch => ProviderInstance::OpenAI(OpenAIProvider), // Arch gateway uses OpenAI-compatible API
// TODO: Implement specific providers for these when they have different APIs
Provider::Gemini => ProviderInstance::OpenAI(OpenAIProvider), // For now, use OpenAI-compatible
Provider::Claude => ProviderInstance::OpenAI(OpenAIProvider), // For now, use OpenAI-compatible
Provider::Github => ProviderInstance::OpenAI(OpenAIProvider), // For now, use OpenAI-compatible
}
}
}
impl Display for Provider {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {

View file

@ -1 +1,3 @@
pub mod openai;
pub mod traits;
pub mod provider_enum;

View file

@ -1,2 +1,3 @@
pub mod builder;
pub mod types;
pub mod provider;

View file

@ -0,0 +1,171 @@
//! OpenAI provider interface implementations
use crate::apis::openai::*;
use crate::providers::traits::*;
use crate::Provider;
/// Simple error type for OpenAI API operations.
///
/// Used as the `Error` associated type by the request, response and
/// streaming-response trait implementations in this module.
#[derive(Debug, thiserror::Error)]
pub enum OpenAIApiError {
/// JSON (de)serialization failed; converted automatically from `serde_json::Error`.
#[error("JSON parsing error: {0}")]
JsonError(#[from] serde_json::Error),
/// The input bytes were not valid UTF-8; converted automatically from `std::str::Utf8Error`.
#[error("UTF-8 parsing error: {0}")]
Utf8Error(#[from] std::str::Utf8Error),
/// A streaming payload could not be parsed into a stream chunk; carries a description.
#[error("Invalid streaming data: {0}")]
InvalidStreamingData(String),
/// A request could not be converted; carries a description.
/// NOTE(review): not constructed anywhere in the code visible here.
#[error("Request conversion error: {0}")]
RequestConversionError(String),
}
// ============================================================================
// OpenAI Provider Definition
// ============================================================================
/// Field-less marker type for the OpenAI provider; its `ProviderInterface`
/// implementation names the concrete request/response types to use.
pub struct OpenAIProvider;
/// Concrete, fully owned streaming response type (owning the data avoids lifetime issues).
///
/// Holds the body split into lines plus a cursor to the next unread line.
pub struct OpenAIStreamingResponse {
    lines: Vec<String>,
    current_index: usize,
}

impl OpenAIStreamingResponse {
    /// Splits `data` into owned lines and positions the cursor at the first one.
    fn new(data: String) -> Self {
        let mut owned_lines = Vec::new();
        for line in data.lines() {
            owned_lines.push(String::from(line));
        }
        Self {
            current_index: 0,
            lines: owned_lines,
        }
    }
}
/// Iterates over the chunks contained in the buffered server-sent-event lines.
///
/// Each `data:` line yields one item: `Ok(chunk)` when its payload parses as a
/// `ChatCompletionsStreamResponse`, otherwise `Err(InvalidStreamingData)`.
/// Lines without a `data:` field and ping payloads are skipped. The `[DONE]`
/// sentinel ends the iteration.
impl Iterator for OpenAIStreamingResponse {
    type Item = Result<ChatCompletionsStreamResponse, OpenAIApiError>;

    fn next(&mut self) -> Option<Self::Item> {
        while let Some(line) = self.lines.get(self.current_index) {
            self.current_index += 1;

            // The SSE format makes the space after the colon optional, so accept
            // both `data: {...}` and `data:{...}`; `trim` removes the optional space.
            let data = match line.strip_prefix("data:") {
                Some(rest) => rest.trim(),
                None => continue, // not a data field (blank line, `event:`, comment, ...)
            };

            if data == "[DONE]" {
                // Move the cursor to the end so the iterator stays exhausted even
                // if `next` is called again and more lines follow the sentinel.
                self.current_index = self.lines.len();
                return None;
            }

            if data == r#"{"type": "ping"}"# {
                continue; // Skip ping messages
            }

            return Some(
                serde_json::from_str::<ChatCompletionsStreamResponse>(data).map_err(|e| {
                    OpenAIApiError::InvalidStreamingData(format!(
                        "Error parsing: {}, data: {}",
                        e, data
                    ))
                }),
            );
        }

        None
    }
}
/// Binds `OpenAIProvider` to the OpenAI chat-completions types: the request,
/// the non-streaming response, the streaming response iterator and the usage record.
impl ProviderInterface for OpenAIProvider {
type Request = ChatCompletionsRequest;
type Response = ChatCompletionsResponse;
type StreamingResponse = OpenAIStreamingResponse;
type Usage = Usage;
}
// ============================================================================
// Trait Implementations for OpenAI Types
// ============================================================================
impl ProviderRequest for ChatCompletionsRequest {
type Error = OpenAIApiError;
fn try_from_bytes(bytes: &[u8]) -> Result<Self, Self::Error> {
let s = std::str::from_utf8(bytes)?;
Ok(serde_json::from_str(s)?)
}
fn to_provider_bytes(&self, _provider: Provider) -> Result<Vec<u8>, Self::Error> {
Ok(serde_json::to_vec(self)?)
}
fn extract_model(&self) -> &str {
&self.model
}
fn is_streaming(&self) -> bool {
self.stream.unwrap_or_default()
}
fn set_streaming_options(&mut self) {
if self.stream_options.is_none() {
self.stream_options = Some(StreamOptions {
include_usage: Some(true),
});
}
}
fn extract_messages_text(&self) -> String {
self.messages
.iter()
.fold(String::new(), |acc, m| {
acc + " " + &match &m.content {
MessageContent::Text(text) => text.clone(),
MessageContent::Parts(parts) => {
parts.iter().map(|part| match part {
ContentPart::Text { text } => text.clone(),
ContentPart::ImageUrl { .. } => "[Image]".to_string(),
}).collect::<Vec<_>>().join(" ")
}
}
})
}
}
/// Exposes the counters of the OpenAI `Usage` record through the generic `TokenUsage` trait.
///
/// NOTE(review): the field types are not visible here; an `as usize` cast would
/// silently truncate a wider or negative value — confirm against the `Usage` definition.
impl TokenUsage for Usage {
/// Returns the `completion_tokens` field converted to `usize`.
fn completion_tokens(&self) -> usize {
self.completion_tokens as usize
}
/// Returns the `prompt_tokens` field converted to `usize`.
fn prompt_tokens(&self) -> usize {
self.prompt_tokens as usize
}
/// Returns the `total_tokens` field converted to `usize`.
fn total_tokens(&self) -> usize {
self.total_tokens as usize
}
}
/// `ProviderResponse` implementation for the OpenAI chat-completions response.
impl ProviderResponse for ChatCompletionsResponse {
    type Error = OpenAIApiError;
    type Usage = Usage;

    /// Decodes `bytes` as UTF-8 and deserializes the text as JSON; the provider is ignored.
    fn try_from_bytes(bytes: &[u8], _provider: &Provider) -> Result<Self, Self::Error> {
        let text = std::str::from_utf8(bytes)?;
        let response: Self = serde_json::from_str(text)?;
        Ok(response)
    }

    /// Always returns the response's usage record.
    fn usage(&self) -> Option<&Self::Usage> {
        let usage = &self.usage;
        Some(usage)
    }
}
/// `StreamChunk` implementation for a single OpenAI streaming chunk.
impl StreamChunk for ChatCompletionsStreamResponse {
type Usage = Usage;
/// Borrows the chunk's usage record, or returns `None` when the chunk carries none.
fn usage(&self) -> Option<&Self::Usage> {
self.usage.as_ref()
}
}
/// `StreamingResponse` implementation backed by the line-buffered OpenAI stream reader.
impl StreamingResponse for OpenAIStreamingResponse {
    type Error = OpenAIApiError;
    type Chunk = ChatCompletionsStreamResponse;

    /// Decodes `bytes` as UTF-8 and buffers the text for line-by-line iteration.
    ///
    /// Fails only when the bytes are not valid UTF-8; individual chunks are parsed
    /// later, during iteration. The provider argument is ignored.
    fn try_from_bytes(bytes: &[u8], _provider: &Provider) -> Result<Self, Self::Error> {
        let text = std::str::from_utf8(bytes)?;
        let owned_text = String::from(text);
        Ok(Self::new(owned_text))
    }
}

View file

@ -0,0 +1,67 @@
use crate::providers::traits::*;
use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse};
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage};
/// Enum that wraps all possible providers for dynamic dispatch
///
/// Only the OpenAI implementation exists at the moment, so this enum has a
/// single variant.
pub enum ProviderInstance {
/// The OpenAI (and OpenAI-compatible) provider implementation.
OpenAI(OpenAIProvider),
// TODO: Add other providers as they are implemented
// Anthropic(AnthropicProvider),
// Mistral(MistralProvider),
// etc.
}
impl ProviderInstance {
/// Creates a provider from a provider name string
pub fn from_name(name: &str) -> Self {
match name.to_lowercase().as_str() {
"openai" | "groq" | "gemini" | "mistral" | "deepseek" | "arch" | "claude" => {
ProviderInstance::OpenAI(OpenAIProvider)
}
// TODO: Add other providers when implemented
// "claude" | "anthropic" => ProviderInstance::Anthropic(AnthropicProvider),
// "mistral" => ProviderInstance::Mistral(MistralProvider),
_ => {
// Default to OpenAI for unknown providers
ProviderInstance::OpenAI(OpenAIProvider)
}
}
}
/// Parse request from bytes using the appropriate provider
pub fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
match self {
ProviderInstance::OpenAI(_) => {
ChatCompletionsRequest::try_from_bytes(bytes).map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
}
// TODO: Add other provider cases when implemented
}
}
/// Parse response from bytes using the appropriate provider
pub fn parse_response(&self, bytes: &[u8], provider: &crate::Provider) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
match self {
ProviderInstance::OpenAI(_) => {
ChatCompletionsResponse::try_from_bytes(bytes, provider).map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
}
// TODO: Add other provider cases when implemented
}
}
/// Parse streaming response from bytes using the appropriate provider
pub fn parse_streaming_response(&self, bytes: &[u8], provider: &crate::Provider) -> Result<OpenAIStreamingResponse, Box<dyn std::error::Error + Send + Sync>> {
match self {
ProviderInstance::OpenAI(_) => {
OpenAIStreamingResponse::try_from_bytes(bytes, provider).map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
}
// TODO: Add other provider cases when implemented
}
}
}
/// Makes the wrapper enum itself usable as a `ProviderInterface`.
///
/// The associated types are the OpenAI ones, matching the return types of the
/// `parse_*` methods on `ProviderInstance`.
impl ProviderInterface for ProviderInstance {
type Request = ChatCompletionsRequest;
type Response = ChatCompletionsResponse;
type StreamingResponse = OpenAIStreamingResponse;
type Usage = Usage;
}

View file

@ -0,0 +1,74 @@
//! Provider traits for generic request/response handling
//!
//! This module defines the core traits that enable provider-agnostic
//! handling of LLM requests and responses in the gateway.
use std::error::Error;
use crate::Provider;
/// Trait for provider-specific request types
///
/// An implementor can be parsed from a raw request body, inspected (model,
/// streaming flag, message text) and serialized again for a target `Provider`.
pub trait ProviderRequest: Sized {
/// Error returned by the fallible parsing and serialization methods.
type Error: Error + Send + Sync + 'static;
/// Parse request from raw bytes
fn try_from_bytes(bytes: &[u8]) -> Result<Self, Self::Error>;
/// Convert to provider-specific format
fn to_provider_bytes(&self, provider: Provider) -> Result<Vec<u8>, Self::Error>;
/// Extract the model name from the request
fn extract_model(&self) -> &str;
/// Check if this is a streaming request
fn is_streaming(&self) -> bool;
/// Set streaming options (e.g., include_usage)
fn set_streaming_options(&mut self);
/// Extract text content from messages for token counting
fn extract_messages_text(&self) -> String;
}
/// Trait for token usage information
///
/// All counts are exposed as `usize`, whatever the implementor stores internally.
pub trait TokenUsage {
/// Number of completion tokens.
fn completion_tokens(&self) -> usize;
/// Number of prompt tokens.
fn prompt_tokens(&self) -> usize;
/// Total number of tokens.
fn total_tokens(&self) -> usize;
}
/// Trait for provider-specific response types
///
/// Covers complete (non-streaming) responses; streamed responses are modeled by
/// `StreamingResponse` and `StreamChunk`.
pub trait ProviderResponse: Sized {
/// Error returned when parsing fails.
type Error: Error + Send + Sync + 'static;
/// Usage record type exposed by `usage`.
type Usage: TokenUsage;
/// Parse response from raw bytes
fn try_from_bytes(bytes: &[u8], provider: &Provider) -> Result<Self, Self::Error>;
/// Get usage information if available
fn usage(&self) -> Option<&Self::Usage>;
}
/// Trait for streaming response chunks
///
/// A chunk is the item type produced by a `StreamingResponse` iterator.
pub trait StreamChunk {
/// Usage record type exposed by `usage`.
type Usage: TokenUsage;
/// Get usage information if available
fn usage(&self) -> Option<&Self::Usage>;
}
/// Trait for streaming response iterators
///
/// An implementor is itself an `Iterator` whose items are `Result<Chunk, Error>`,
/// so a single malformed chunk can be reported without aborting construction.
pub trait StreamingResponse: Iterator<Item = Result<Self::Chunk, Self::Error>> + Sized {
/// Error returned by construction and by individual iteration items.
type Error: Error + Send + Sync + 'static;
/// Chunk type yielded by the iterator.
type Chunk: StreamChunk;
/// Parse streaming response from raw bytes
fn try_from_bytes(bytes: &[u8], provider: &Provider) -> Result<Self, Self::Error>;
}
/// Main provider interface trait
///
/// Declares no methods; it only groups the concrete types a provider uses.
pub trait ProviderInterface {
/// Request type accepted by the provider.
type Request: ProviderRequest;
/// Complete (non-streaming) response type.
type Response: ProviderResponse;
/// Iterator type over streamed response chunks.
type StreamingResponse: StreamingResponse;
/// Token usage record type.
type Usage: TokenUsage;
}

View file

@ -10,11 +10,9 @@ use common::ratelimit::Header;
use common::stats::{IncrementingMetric, RecordingMetric};
use common::tracing::{Event, Span, TraceData, Traceparent};
use common::{ratelimit, routing, tokenizer};
use hermesllm::providers::openai::types::{ChatCompletionsRequest, SseChatCompletionIter};
use hermesllm::providers::openai::types::{
ChatCompletionsResponse, ContentType, Message, StreamOptions,
use hermesllm::{
Provider, ProviderInstance, ProviderRequest, ProviderResponse, StreamChunk, TokenUsage,
};
use hermesllm::Provider;
use http::StatusCode;
use log::{debug, info, warn};
use proxy_wasm::hostcalls::get_current_time;
@ -41,7 +39,6 @@ pub struct StreamContext {
ttft_time: Option<u128>,
traceparent: Option<String>,
request_body_sent_time: Option<u128>,
user_message: Option<Message>,
traces_queue: Arc<Mutex<VecDeque<TraceData>>>,
overrides: Rc<Option<Overrides>>,
}
@ -69,7 +66,6 @@ impl StreamContext {
ttft_duration: None,
traceparent: None,
ttft_time: None,
user_message: None,
traces_queue,
request_body_sent_time: None,
}
@ -80,6 +76,10 @@ impl StreamContext {
.expect("the provider should be set when asked for it")
}
/// Builds the provider instance for the currently selected LLM provider.
// NOTE(review): goes through `llm_provider()`, which presumably panics when no
// provider has been selected yet — confirm against that accessor.
fn get_provider_instance(&self) -> ProviderInstance {
    let selected_provider = self.llm_provider();
    selected_provider.create_provider_instance()
}
fn select_llm_provider(&mut self) {
let provider_hint = self
.get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
@ -295,52 +295,39 @@ impl HttpContext for StreamContext {
}
};
let mut deserialized_body = match ChatCompletionsRequest::try_from(body_bytes.as_slice()) {
let provider_instance = self.get_provider_instance();
let mut deserialized_body = match provider_instance.parse_request(&body_bytes) {
Ok(deserialized) => deserialized,
Err(e) => {
debug!(
"on_http_request_body: request body: {}",
String::from_utf8_lossy(&body_bytes)
);
self.send_server_error(ServerError::OpenAIPError(e), Some(StatusCode::BAD_REQUEST));
self.send_server_error(
ServerError::LogicError(format!("Request parsing error: {}", e)),
Some(StatusCode::BAD_REQUEST),
);
return Action::Pause;
}
};
self.user_message = deserialized_body
.messages
.iter()
.filter(|m| m.role == "user")
.last()
.cloned();
// TODO: For now, we'll need to handle user_message extraction differently since it's OpenAI-specific
// This could be made generic by adding a trait method later
let model_name = match self.llm_provider.as_ref() {
Some(llm_provider) => llm_provider.model.as_ref(),
None => None,
};
let use_agent_orchestrator = match self.overrides.as_ref() {
let _use_agent_orchestrator = match self.overrides.as_ref() {
Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(),
None => false,
};
let model_requested = deserialized_body.model.clone();
deserialized_body.model = match model_name {
Some(model_name) => model_name.clone(),
None => {
if use_agent_orchestrator {
"agent_orchestrator".to_string()
} else {
self.send_server_error(
ServerError::BadRequest {
why: format!("No model specified in request and couldn't determine model name from arch_config. Model name in req: {}, arch_config, provider: {}, model: {:?}", deserialized_body.model, self.llm_provider().name, self.llm_provider().model).to_string(),
},
Some(StatusCode::BAD_REQUEST),
);
return Action::Continue;
}
}
};
let model_requested = deserialized_body.extract_model().to_string();
// Note: We can't directly modify the model field through the trait,
// this would need to be handled differently in a full generic implementation
info!(
"on_http_request_body: provider: {}, model requested (in body): {}, model selected: {}",
@ -349,32 +336,17 @@ impl HttpContext for StreamContext {
model_name.unwrap_or(&"None".to_string()),
);
if deserialized_body.stream.unwrap_or_default() {
if deserialized_body.is_streaming() {
self.streaming_response = true;
}
if deserialized_body.stream.unwrap_or_default()
&& deserialized_body.stream_options.is_none()
{
deserialized_body.stream_options = Some(StreamOptions {
include_usage: true,
});
if deserialized_body.is_streaming() {
deserialized_body.set_streaming_options();
}
// only use the tokens from the messages, excluding the metadata and json tags
let input_tokens_str = deserialized_body
.messages
.iter()
.fold(String::new(), |acc, m| {
acc + " "
+ m.content
.as_ref()
.unwrap_or(&ContentType::Text(String::new()))
.to_string()
.as_str()
});
let input_tokens_str = deserialized_body.extract_messages_text();
// enforce ratelimits on ingress
if let Err(e) = self.enforce_ratelimits(&deserialized_body.model, input_tokens_str.as_str())
{
if let Err(e) = self.enforce_ratelimits(&model_requested, input_tokens_str.as_str()) {
self.send_server_error(
ServerError::ExceededRatelimit(e),
Some(StatusCode::TOO_MANY_REQUESTS),
@ -387,11 +359,15 @@ impl HttpContext for StreamContext {
let hermes_llm_provider = Provider::from(llm_provider_str.as_str());
// convert chat completion request to llm provider specific request
let deserialized_body_bytes = match deserialized_body.to_bytes(hermes_llm_provider) {
let deserialized_body_bytes = match deserialized_body.to_provider_bytes(hermes_llm_provider)
{
Ok(bytes) => bytes,
Err(e) => {
warn!("Failed to serialize request body: {}", e);
self.send_server_error(ServerError::OpenAIPError(e), Some(StatusCode::BAD_REQUEST));
self.send_server_error(
ServerError::LogicError(format!("Request serialization error: {}", e)),
Some(StatusCode::BAD_REQUEST),
);
return Action::Pause;
}
};
@ -484,12 +460,6 @@ impl HttpContext for StreamContext {
self.request_body_sent_time.unwrap(),
current_time_ns,
);
if let Some(user_message) = self.user_message.as_ref() {
if let Some(prompt) = user_message.content.as_ref() {
llm_span
.add_attribute("user_prompt".to_string(), prompt.to_string());
}
}
llm_span.add_attribute(
"model".to_string(),
self.llm_provider().name.to_string(),
@ -562,8 +532,11 @@ impl HttpContext for StreamContext {
let hermes_llm_provider = Provider::from(llm_provider_str.as_str());
if self.streaming_response {
let chat_completions_chunk_response_events =
match SseChatCompletionIter::try_from((body.as_slice(), &hermes_llm_provider)) {
// Use the provider instance to parse streaming response
let provider_instance = self.get_provider_instance();
let streaming_events =
match provider_instance.parse_streaming_response(&body, &hermes_llm_provider) {
Ok(events) => events,
Err(e) => {
warn!(
@ -575,11 +548,11 @@ impl HttpContext for StreamContext {
}
};
for event in chat_completions_chunk_response_events {
match event {
for event_result in streaming_events {
match event_result {
Ok(event) => {
if let Some(usage) = event.usage.as_ref() {
self.response_tokens += usage.completion_tokens;
if let Some(usage) = event.usage() {
self.response_tokens += usage.completion_tokens();
}
}
Err(e) => {
@ -611,30 +584,30 @@ impl HttpContext for StreamContext {
}
} else {
debug!("non streaming response");
let chat_completions_response =
match ChatCompletionsResponse::try_from((body.as_slice(), &hermes_llm_provider)) {
Ok(de) => de,
Err(e) => {
warn!(
"could not parse response: {}, body str: {}",
e,
String::from_utf8_lossy(&body)
);
debug!(
"on_http_response_body: S[{}], response body: {}",
self.context_id,
String::from_utf8_lossy(&body)
);
self.send_server_error(
ServerError::OpenAIPError(e),
Some(StatusCode::BAD_REQUEST),
);
return Action::Continue;
}
};
let provider_instance = self.get_provider_instance();
let response = match provider_instance.parse_response(&body, &hermes_llm_provider) {
Ok(de) => de,
Err(e) => {
warn!(
"could not parse response: {}, body str: {}",
e,
String::from_utf8_lossy(&body)
);
debug!(
"on_http_response_body: S[{}], response body: {}",
self.context_id,
String::from_utf8_lossy(&body)
);
self.send_server_error(
ServerError::LogicError(format!("Response parsing error: {}", e)),
Some(StatusCode::BAD_REQUEST),
);
return Action::Continue;
}
};
if let Some(usage) = chat_completions_response.usage {
self.response_tokens += usage.completion_tokens;
if let Some(usage) = response.usage() {
self.response_tokens += usage.completion_tokens();
}
}