saving changes, although we will need a small re-factor after this as well

This commit is contained in:
Salman Paracha 2025-08-09 11:19:23 -07:00
parent 203fc8f9a9
commit 63f23efda4
10 changed files with 414 additions and 259 deletions

View file

@ -178,14 +178,13 @@ impl Display for LlmProviderType {
}
impl LlmProviderType {
/// Create a ProviderInstance from this LlmProviderType
/// Create a Provider from this LlmProviderType
/// This is the main method for stream_context to get provider-specific interfaces
pub fn create_provider_instance(&self) -> hermesllm::ProviderInstance {
use hermesllm::ProviderInstance;
pub fn create_provider(&self) -> hermesllm::Provider {
use hermesllm::{ProviderId, Provider};
// For now, all providers use OpenAI-compatible APIs
// TODO: Return specific provider instances when implementing different APIs
ProviderInstance::from_name(&self.to_string())
let provider_id = ProviderId::from(self.to_string().as_str());
Provider::new(provider_id)
}
}
@ -265,10 +264,10 @@ impl Display for LlmProvider {
}
impl LlmProvider {
/// Create a ProviderInstance for this LlmProvider
/// Create a Provider for this LlmProvider
/// This is a convenience method that delegates to the provider_interface
pub fn create_provider_instance(&self) -> hermesllm::ProviderInstance {
self.provider_interface.create_provider_instance()
pub fn create_provider(&self) -> hermesllm::Provider {
self.provider_interface.create_provider()
}
}

View file

@ -5,121 +5,59 @@ pub mod providers;
pub mod apis;
pub mod clients;
// Re-export important traits
pub use providers::traits::*;
pub use providers::openai::provider::OpenAIProvider;
pub use providers::provider_enum::ProviderInstance;
use std::fmt::Display;
pub enum Provider {
Arch,
Mistral,
Deepseek,
Groq,
Gemini,
OpenAI,
Claude,
Github,
}
impl From<&str> for Provider {
fn from(value: &str) -> Self {
match value.to_lowercase().as_str() {
"arch" => Provider::Arch,
"mistral" => Provider::Mistral,
"deepseek" => Provider::Deepseek,
"groq" => Provider::Groq,
"gemini" => Provider::Gemini,
"openai" => Provider::OpenAI,
"claude" => Provider::Claude,
"github" => Provider::Github,
_ => panic!("Unknown provider: {}", value),
}
}
}
impl Provider {
/// Get the API endpoint path for this provider
pub fn api_path(&self) -> &'static str {
match self {
Provider::OpenAI => "/v1/chat/completions",
Provider::Groq => "/openai/v1/chat/completions", // Groq maps to OpenAI-compatible endpoint
Provider::Gemini => "/v1/models", // TODO: Update with correct Gemini path
Provider::Claude => "/v1/messages", // TODO: Update with correct Claude path
Provider::Mistral => "/v1/chat/completions", // Mistral uses OpenAI-compatible API
Provider::Deepseek => "/v1/chat/completions", // DeepSeek uses OpenAI-compatible API
Provider::Arch => "/v1/chat/completions", // Arch gateway endpoint
Provider::Github => "/models", // TODO: Update with correct GitHub models path
}
}
/// Check if this provider uses OpenAI-compatible API format
pub fn uses_openai_format(&self) -> bool {
match self {
Provider::OpenAI | Provider::Groq | Provider::Mistral | Provider::Deepseek | Provider::Arch => true,
Provider::Gemini | Provider::Claude | Provider::Github => false, // These have their own formats
}
}
/// Create a provider implementation instance for this provider
pub fn create_provider_instance(&self) -> ProviderInstance {
match self {
Provider::OpenAI => ProviderInstance::OpenAI(OpenAIProvider),
Provider::Groq => ProviderInstance::OpenAI(OpenAIProvider), // Groq uses OpenAI-compatible API
Provider::Mistral => ProviderInstance::OpenAI(OpenAIProvider), // Mistral uses OpenAI-compatible API
Provider::Deepseek => ProviderInstance::OpenAI(OpenAIProvider), // Deepseek uses OpenAI-compatible API
Provider::Arch => ProviderInstance::OpenAI(OpenAIProvider), // Arch gateway uses OpenAI-compatible API
// TODO: Implement specific providers for these when they have different APIs
Provider::Gemini => ProviderInstance::OpenAI(OpenAIProvider), // For now, use OpenAI-compatible
Provider::Claude => ProviderInstance::OpenAI(OpenAIProvider), // For now, use OpenAI-compatible
Provider::Github => ProviderInstance::OpenAI(OpenAIProvider), // For now, use OpenAI-compatible
}
}
}
impl Display for Provider {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Provider::Arch => write!(f, "Arch"),
Provider::Mistral => write!(f, "Mistral"),
Provider::Deepseek => write!(f, "Deepseek"),
Provider::Groq => write!(f, "Groq"),
Provider::Gemini => write!(f, "Gemini"),
Provider::OpenAI => write!(f, "OpenAI"),
Provider::Claude => write!(f, "Claude"),
Provider::Github => write!(f, "Github"),
}
}
}
// Re-export important types and traits
pub use providers::{
ProviderId, Provider, ConversionMode,
ProviderInterface, ProviderRequest, ProviderResponse,
TokenUsage, StreamChunk, StreamingResponse,
OpenAIProvider
};
#[cfg(test)]
mod tests {
use crate::providers::openai::types::{ChatCompletionsRequest, Message};
use super::*;
#[test]
fn openai_builder() {
let request =
ChatCompletionsRequest::builder("gpt-3.5-turbo", vec![Message::new("Hi".to_string())])
.temperature(0.7)
.top_p(0.9)
.n(1)
.max_tokens(100)
.stream(false)
.stop(vec!["\n".to_string()])
.presence_penalty(0.0)
.frequency_penalty(0.0)
.build()
.expect("Failed to build OpenAIRequest");
fn test_provider_id_conversion() {
assert_eq!(ProviderId::from("openai"), ProviderId::OpenAI);
assert_eq!(ProviderId::from("mistral"), ProviderId::Mistral);
assert_eq!(ProviderId::from("groq"), ProviderId::Groq);
assert_eq!(ProviderId::from("arch"), ProviderId::Arch);
}
assert_eq!(request.model, "gpt-3.5-turbo");
assert_eq!(request.temperature, Some(0.7));
assert_eq!(request.top_p, Some(0.9));
assert_eq!(request.n, Some(1));
assert_eq!(request.max_tokens, Some(100));
assert_eq!(request.stream, Some(false));
assert_eq!(request.stop, Some(vec!["\n".to_string()]));
assert_eq!(request.presence_penalty, Some(0.0));
assert_eq!(request.frequency_penalty, Some(0.0));
#[test]
fn test_provider_api_paths() {
assert_eq!(ProviderId::OpenAI.api_path(), "/v1/chat/completions");
assert_eq!(ProviderId::Groq.api_path(), "/openai/v1/chat/completions");
assert_eq!(ProviderId::Mistral.api_path(), "/v1/chat/completions");
assert_eq!(ProviderId::Arch.api_path(), "/v1/chat/completions");
}
#[test]
fn test_provider_openai_format_support() {
assert!(ProviderId::OpenAI.supports_openai_format());
assert!(ProviderId::Groq.supports_openai_format());
assert!(ProviderId::Mistral.supports_openai_format());
assert!(ProviderId::Arch.supports_openai_format());
assert!(!ProviderId::Gemini.supports_openai_format());
assert!(!ProviderId::Claude.supports_openai_format());
}
#[test]
fn test_provider_instance_creation() {
let provider = Provider::new(ProviderId::OpenAI);
assert!(provider.has_compatible_api("/v1/chat/completions"));
assert!(!provider.has_compatible_api("/v1/embeddings"));
}
#[test]
fn test_conversion_mode() {
let provider = Provider::new(ProviderId::OpenAI);
let compatible_mode = provider.get_interface(false);
assert!(matches!(compatible_mode, ConversionMode::Compatible));
let passthrough_mode = provider.get_interface(true);
assert!(matches!(passthrough_mode, ConversionMode::Passthrough));
}
}

View file

@ -0,0 +1,109 @@
//! Provider interface trait definitions
//!
//! This module defines the core traits that all LLM providers must implement.
//! The interface is designed around v1/chat/completions API for simplicity.
use std::error::Error;
/// Conversion mode for provider requests/responses
#[derive(Debug, Clone, Copy)]
pub enum ConversionMode {
/// Compatible: Convert between different provider formats to ensure compatibility
Compatible,
/// Passthrough: Pass requests/responses through with minimal modification
Passthrough,
}
/// Token usage information
pub trait TokenUsage {
fn completion_tokens(&self) -> usize;
fn prompt_tokens(&self) -> usize;
fn total_tokens(&self) -> usize;
}
/// Error type for provider operations
pub trait ProviderError: Error + Send + Sync + 'static {}
/// Request type that can be converted to/from provider-specific formats
pub trait ProviderRequest: Sized {
type Error: ProviderError;
/// Parse request from raw bytes (typically JSON)
fn from_bytes(bytes: &[u8]) -> Result<Self, Self::Error>;
/// Convert to bytes for sending to upstream API
fn to_bytes(&self, mode: ConversionMode) -> Result<Vec<u8>, Self::Error>;
/// Extract the model name from the request
fn model(&self) -> &str;
/// Check if this is a streaming request
fn is_streaming(&self) -> bool;
/// Set streaming options (e.g., include_usage)
fn set_streaming_options(&mut self);
/// Extract text content from messages for token counting
fn extract_text(&self) -> String;
}
/// Response type that can be converted to/from provider-specific formats
pub trait ProviderResponse: Sized {
type Error: ProviderError;
type Usage: TokenUsage;
/// Parse response from raw bytes (typically JSON)
fn from_bytes(bytes: &[u8], mode: ConversionMode) -> Result<Self, Self::Error>;
/// Convert to bytes for sending to client
fn to_bytes(&self) -> Result<Vec<u8>, Self::Error>;
/// Get usage information if available
fn usage(&self) -> Option<&Self::Usage>;
}
/// Streaming response chunk
pub trait StreamChunk: Sized {
type Error: ProviderError;
type Usage: TokenUsage;
/// Parse chunk from a line of streaming data
fn from_line(line: &str, mode: ConversionMode) -> Result<Option<Self>, Self::Error>;
/// Convert to line for sending to client
fn to_line(&self) -> Result<String, Self::Error>;
/// Get usage information if available (usually only in final chunk)
fn usage(&self) -> Option<&Self::Usage>;
/// Check if this is the final chunk in the stream
fn is_final(&self) -> bool;
}
/// Main provider interface
pub trait LLMProvider {
type Request: ProviderRequest;
type Response: ProviderResponse;
type StreamChunk: StreamChunk;
type Error: ProviderError;
/// Create a new instance of this provider
fn new() -> Self;
/// Get the supported API endpoints for this provider
fn supported_apis(&self) -> Vec<&'static str>;
/// Check if the provider supports v1/chat/completions API
fn supports_chat_completions(&self) -> bool {
self.supported_apis().contains(&"/v1/chat/completions")
}
/// Parse a request from raw bytes
fn parse_request(&self, bytes: &[u8]) -> Result<Self::Request, Self::Error>;
/// Parse a response from raw bytes
fn parse_response(&self, bytes: &[u8], mode: ConversionMode) -> Result<Self::Response, Self::Error>;
/// Parse streaming response chunks from raw data
fn parse_stream_chunk(&self, line: &str, mode: ConversionMode) -> Result<Option<Self::StreamChunk>, Self::Error>;
}

View file

@ -1,3 +1,174 @@
pub mod openai;
//! Provider implementations for different LLM APIs
//!
//! This module contains provider-specific implementations that handle
//! request/response conversion for different LLM service APIs.
pub mod traits;
pub mod provider_enum;
pub mod openai;
// Re-export the main interfaces
pub use traits::*;
pub use openai::OpenAIProvider;
use std::fmt::Display;
/// Provider identifier enum - simple enum for identifying providers
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProviderId {
OpenAI,
Mistral,
Deepseek,
Groq,
Gemini,
Claude,
GitHub,
Arch,
}
impl From<&str> for ProviderId {
fn from(value: &str) -> Self {
match value.to_lowercase().as_str() {
"openai" => ProviderId::OpenAI,
"mistral" => ProviderId::Mistral,
"deepseek" => ProviderId::Deepseek,
"groq" => ProviderId::Groq,
"gemini" => ProviderId::Gemini,
"claude" => ProviderId::Claude,
"github" => ProviderId::GitHub,
"arch" => ProviderId::Arch,
_ => panic!("Unknown provider: {}", value),
}
}
}
impl Display for ProviderId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ProviderId::OpenAI => write!(f, "OpenAI"),
ProviderId::Mistral => write!(f, "Mistral"),
ProviderId::Deepseek => write!(f, "Deepseek"),
ProviderId::Groq => write!(f, "Groq"),
ProviderId::Gemini => write!(f, "Gemini"),
ProviderId::Claude => write!(f, "Claude"),
ProviderId::GitHub => write!(f, "GitHub"),
ProviderId::Arch => write!(f, "Arch"),
}
}
}
impl ProviderId {
/// Get the API endpoint path for this provider
pub fn api_path(&self) -> &'static str {
match self {
ProviderId::OpenAI => "/v1/chat/completions",
ProviderId::Groq => "/openai/v1/chat/completions",
ProviderId::Gemini => "/v1/models", // TODO: Update when Gemini API is implemented
ProviderId::Claude => "/v1/messages", // TODO: Update when Claude API is implemented
ProviderId::Mistral => "/v1/chat/completions",
ProviderId::Deepseek => "/v1/chat/completions",
ProviderId::GitHub => "/models", // TODO: Update when GitHub models API is implemented
ProviderId::Arch => "/v1/chat/completions",
}
}
/// Check if this provider supports OpenAI v1/chat/completions API format
pub fn supports_openai_format(&self) -> bool {
matches!(
self,
ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch
)
}
}
/// Enum for dynamic dispatch of provider instances
/// For now, most providers use OpenAI-compatible format
pub enum Provider {
OpenAI(OpenAIProvider, ProviderId),
// TODO: Add specific implementations when providers have different APIs
// Mistral(MistralProvider, ProviderId),
// Groq(GroqProvider, ProviderId),
// etc.
}
impl Provider {
/// Create a provider instance from a provider ID
pub fn new(id: ProviderId) -> Self {
match id {
// For now, all providers that support v1/chat/completions use OpenAI format
ProviderId::OpenAI | ProviderId::Groq | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch => {
Provider::OpenAI(OpenAIProvider, id)
}
// TODO: Implement specific providers when they have different APIs
ProviderId::Gemini | ProviderId::Claude | ProviderId::GitHub => {
Provider::OpenAI(OpenAIProvider, id) // Fallback to OpenAI for now
}
}
}
/// Get the provider ID
pub fn id(&self) -> ProviderId {
match self {
Provider::OpenAI(_, id) => *id,
}
}
/// Check if this provider has a compatible API with the client request
pub fn has_compatible_api(&self, api_path: &str) -> bool {
match self {
Provider::OpenAI(provider, _) => provider.has_compatible_api(api_path),
}
}
/// Get the interface implementation for this provider
pub fn get_interface(&self, passthrough: bool) -> ConversionMode {
match self {
Provider::OpenAI(provider, _) => provider.get_interface(passthrough),
}
}
/// Parse a request from raw bytes - returns the concrete OpenAI request type for now
pub fn parse_request(&self, bytes: &[u8]) -> Result<crate::apis::openai::ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
match self {
Provider::OpenAI(_, _) => {
use crate::apis::openai::ChatCompletionsRequest;
use crate::providers::traits::ProviderRequest;
match ChatCompletionsRequest::try_from_bytes(bytes) {
Ok(req) => Ok(req),
Err(e) => Err(Box::new(e)),
}
}
}
}
/// Parse a response from raw bytes - returns the concrete OpenAI response type for now
pub fn parse_response(&self, bytes: &[u8], mode: ConversionMode) -> Result<crate::apis::openai::ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
match self {
Provider::OpenAI(_, _) => {
use crate::apis::openai::ChatCompletionsResponse;
use crate::providers::traits::ProviderResponse;
let provider_id = self.id();
match ChatCompletionsResponse::try_from_bytes(bytes, &provider_id, mode) {
Ok(resp) => Ok(resp),
Err(e) => Err(Box::new(e)),
}
}
}
}
/// Convert a request to bytes for sending to upstream API
pub fn request_to_bytes(&self, request: &crate::apis::openai::ChatCompletionsRequest, mode: ConversionMode) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
match self {
Provider::OpenAI(_, _) => {
use crate::providers::traits::ProviderRequest;
let provider_id = self.id();
match request.to_provider_bytes(provider_id, mode) {
Ok(bytes) => Ok(bytes),
Err(e) => Err(Box::new(e)),
}
}
}
}
}

View file

@ -1,3 +1,6 @@
pub mod builder;
pub mod types;
pub mod provider;
// Re-export the main provider
pub use provider::OpenAIProvider;

View file

@ -2,7 +2,6 @@
use crate::apis::openai::*;
use crate::providers::traits::*;
use crate::Provider;
// Simple error type for OpenAI API operations
#[derive(Debug, thiserror::Error)]
@ -73,6 +72,14 @@ impl ProviderInterface for OpenAIProvider {
type Response = ChatCompletionsResponse;
type StreamingResponse = OpenAIStreamingResponse;
type Usage = Usage;
fn has_compatible_api(&self, api_path: &str) -> bool {
api_path == "/v1/chat/completions"
}
fn supported_apis(&self) -> Vec<&'static str> {
vec!["/v1/chat/completions"]
}
}
// ============================================================================
@ -87,7 +94,7 @@ impl ProviderRequest for ChatCompletionsRequest {
Ok(serde_json::from_str(s)?)
}
fn to_provider_bytes(&self, _provider: Provider) -> Result<Vec<u8>, Self::Error> {
fn to_provider_bytes(&self, _provider: super::super::ProviderId, _mode: ConversionMode) -> Result<Vec<u8>, Self::Error> {
Ok(serde_json::to_vec(self)?)
}
@ -142,7 +149,7 @@ impl ProviderResponse for ChatCompletionsResponse {
type Error = OpenAIApiError;
type Usage = Usage;
fn try_from_bytes(bytes: &[u8], _provider: &Provider) -> Result<Self, Self::Error> {
fn try_from_bytes(bytes: &[u8], _provider: &super::super::ProviderId, _mode: ConversionMode) -> Result<Self, Self::Error> {
let s = std::str::from_utf8(bytes)?;
Ok(serde_json::from_str(s)?)
}
@ -164,7 +171,7 @@ impl StreamingResponse for OpenAIStreamingResponse {
type Error = OpenAIApiError;
type Chunk = ChatCompletionsStreamResponse;
fn try_from_bytes(bytes: &[u8], _provider: &Provider) -> Result<Self, Self::Error> {
fn try_from_bytes(bytes: &[u8], _provider: &super::super::ProviderId, _mode: ConversionMode) -> Result<Self, Self::Error> {
let s = std::str::from_utf8(bytes)?;
Ok(OpenAIStreamingResponse::new(s.to_string()))
}

View file

@ -8,7 +8,7 @@ use std::convert::TryFrom;
use std::str;
use thiserror::Error;
use crate::Provider;
use crate::providers::ProviderId;
#[derive(Debug, Error)]
pub enum OpenAIError {
@ -144,28 +144,26 @@ impl TryFrom<&[u8]> for ChatCompletionsResponse {
}
}
impl<'a> TryFrom<(&'a [u8], &'a Provider)> for ChatCompletionsResponse {
impl<'a> TryFrom<(&'a [u8], &'a ProviderId)> for ChatCompletionsResponse {
type Error = OpenAIError;
fn try_from(input: (&'a [u8], &'a Provider)) -> Result<Self> {
fn try_from(input: (&'a [u8], &'a ProviderId)) -> Result<Self> {
// Use input.provider as needed, if necessary
serde_json::from_slice(input.0).map_err(OpenAIError::from)
}
}
impl ChatCompletionsRequest {
pub fn to_bytes(&self, provider: Provider) -> Result<Vec<u8>> {
pub fn to_bytes(&self, provider: ProviderId) -> Result<Vec<u8>> {
match provider {
Provider::OpenAI
| Provider::Arch
| Provider::Deepseek
| Provider::Mistral
| Provider::Groq
| Provider::Gemini
| Provider::Claude => serde_json::to_vec(self).map_err(OpenAIError::from),
_ => Err(OpenAIError::UnsupportedProvider {
provider: provider.to_string(),
}),
ProviderId::OpenAI
| ProviderId::Arch
| ProviderId::Deepseek
| ProviderId::Mistral
| ProviderId::Groq
| ProviderId::Gemini
| ProviderId::Claude
| ProviderId::GitHub => serde_json::to_vec(self).map_err(OpenAIError::from),
}
}
}
@ -262,10 +260,10 @@ where
}
}
impl<'a> TryFrom<(&'a [u8], &'a Provider)> for SseChatCompletionIter<str::Lines<'a>> {
impl<'a> TryFrom<(&'a [u8], &'a ProviderId)> for SseChatCompletionIter<str::Lines<'a>> {
type Error = OpenAIError;
fn try_from(input: (&'a [u8], &'a Provider)) -> Result<Self> {
fn try_from(input: (&'a [u8], &'a ProviderId)) -> Result<Self> {
let s = std::str::from_utf8(input.0)?;
// Use input.provider as needed
Ok(SseChatCompletionIter::new(s.lines()))

View file

@ -1,67 +0,0 @@
use crate::providers::traits::*;
use crate::providers::openai::provider::{OpenAIProvider, OpenAIStreamingResponse};
use crate::apis::openai::{ChatCompletionsRequest, ChatCompletionsResponse, Usage};
/// Enum that wraps all possible providers for dynamic dispatch
pub enum ProviderInstance {
OpenAI(OpenAIProvider),
// TODO: Add other providers as they are implemented
// Anthropic(AnthropicProvider),
// Mistral(MistralProvider),
// etc.
}
impl ProviderInstance {
/// Creates a provider from a provider name string
pub fn from_name(name: &str) -> Self {
match name.to_lowercase().as_str() {
"openai" | "groq" | "gemini" | "mistral" | "deepseek" | "arch" | "claude" => {
ProviderInstance::OpenAI(OpenAIProvider)
}
// TODO: Add other providers when implemented
// "claude" | "anthropic" => ProviderInstance::Anthropic(AnthropicProvider),
// "mistral" => ProviderInstance::Mistral(MistralProvider),
_ => {
// Default to OpenAI for unknown providers
ProviderInstance::OpenAI(OpenAIProvider)
}
}
}
/// Parse request from bytes using the appropriate provider
pub fn parse_request(&self, bytes: &[u8]) -> Result<ChatCompletionsRequest, Box<dyn std::error::Error + Send + Sync>> {
match self {
ProviderInstance::OpenAI(_) => {
ChatCompletionsRequest::try_from_bytes(bytes).map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
}
// TODO: Add other provider cases when implemented
}
}
/// Parse response from bytes using the appropriate provider
pub fn parse_response(&self, bytes: &[u8], provider: &crate::Provider) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error + Send + Sync>> {
match self {
ProviderInstance::OpenAI(_) => {
ChatCompletionsResponse::try_from_bytes(bytes, provider).map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
}
// TODO: Add other provider cases when implemented
}
}
/// Parse streaming response from bytes using the appropriate provider
pub fn parse_streaming_response(&self, bytes: &[u8], provider: &crate::Provider) -> Result<OpenAIStreamingResponse, Box<dyn std::error::Error + Send + Sync>> {
match self {
ProviderInstance::OpenAI(_) => {
OpenAIStreamingResponse::try_from_bytes(bytes, provider).map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
}
// TODO: Add other provider cases when implemented
}
}
}
impl ProviderInterface for ProviderInstance {
type Request = ChatCompletionsRequest;
type Response = ChatCompletionsResponse;
type StreamingResponse = OpenAIStreamingResponse;
type Usage = Usage;
}

View file

@ -4,7 +4,15 @@
//! handling of LLM requests and responses in the gateway.
use std::error::Error;
use crate::Provider;
/// Conversion mode for provider requests/responses
#[derive(Debug, Clone, Copy)]
pub enum ConversionMode {
/// Compatible: Convert between different provider formats to ensure compatibility
Compatible,
/// Passthrough: Pass requests/responses through with minimal modification
Passthrough,
}
/// Trait for provider-specific request types
pub trait ProviderRequest: Sized {
@ -14,7 +22,7 @@ pub trait ProviderRequest: Sized {
fn try_from_bytes(bytes: &[u8]) -> Result<Self, Self::Error>;
/// Convert to provider-specific format
fn to_provider_bytes(&self, provider: Provider) -> Result<Vec<u8>, Self::Error>;
fn to_provider_bytes(&self, provider: super::ProviderId, mode: ConversionMode) -> Result<Vec<u8>, Self::Error>;
/// Extract the model name from the request
fn extract_model(&self) -> &str;
@ -42,7 +50,7 @@ pub trait ProviderResponse: Sized {
type Usage: TokenUsage;
/// Parse response from raw bytes
fn try_from_bytes(bytes: &[u8], provider: &Provider) -> Result<Self, Self::Error>;
fn try_from_bytes(bytes: &[u8], provider: &super::ProviderId, mode: ConversionMode) -> Result<Self, Self::Error>;
/// Get usage information if available
fn usage(&self) -> Option<&Self::Usage>;
@ -62,7 +70,7 @@ pub trait StreamingResponse: Iterator<Item = Result<Self::Chunk, Self::Error>> +
type Chunk: StreamChunk;
/// Parse streaming response from raw bytes
fn try_from_bytes(bytes: &[u8], provider: &Provider) -> Result<Self, Self::Error>;
fn try_from_bytes(bytes: &[u8], provider: &super::ProviderId, mode: ConversionMode) -> Result<Self, Self::Error>;
}
/// Main provider interface trait
@ -71,4 +79,20 @@ pub trait ProviderInterface {
type Response: ProviderResponse;
type StreamingResponse: StreamingResponse;
type Usage: TokenUsage;
/// Check if this provider has a compatible API with the client request
fn has_compatible_api(&self, api_path: &str) -> bool;
/// Get the interface implementation for this provider
/// passthrough: if true, use provider-specific format; if false, use compatible format
fn get_interface(&self, passthrough: bool) -> ConversionMode {
if passthrough {
ConversionMode::Passthrough
} else {
ConversionMode::Compatible
}
}
/// Get supported API endpoints for this provider
fn supported_apis(&self) -> Vec<&'static str>;
}

View file

@ -10,9 +10,7 @@ use common::ratelimit::Header;
use common::stats::{IncrementingMetric, RecordingMetric};
use common::tracing::{Event, Span, TraceData, Traceparent};
use common::{ratelimit, routing, tokenizer};
use hermesllm::{
Provider, ProviderInstance, ProviderRequest, ProviderResponse, StreamChunk, TokenUsage,
};
use hermesllm::{ConversionMode, Provider, ProviderId, ProviderRequest};
use http::StatusCode;
use log::{debug, info, warn};
use proxy_wasm::hostcalls::get_current_time;
@ -76,8 +74,8 @@ impl StreamContext {
.expect("the provider should be set when asked for it")
}
fn get_provider_instance(&self) -> ProviderInstance {
self.llm_provider().create_provider_instance()
fn get_provider(&self) -> Provider {
self.llm_provider().create_provider()
}
fn select_llm_provider(&mut self) {
@ -295,9 +293,9 @@ impl HttpContext for StreamContext {
}
};
let provider_instance = self.get_provider_instance();
let provider = self.get_provider();
let mut deserialized_body = match provider_instance.parse_request(&body_bytes) {
let mut deserialized_body = match provider.parse_request(&body_bytes) {
Ok(deserialized) => deserialized,
Err(e) => {
debug!(
@ -356,10 +354,11 @@ impl HttpContext for StreamContext {
}
let llm_provider_str = self.llm_provider().provider_interface.to_string();
let hermes_llm_provider = Provider::from(llm_provider_str.as_str());
let hermes_llm_provider_id = ProviderId::from(llm_provider_str.as_str());
// convert chat completion request to llm provider specific request
let deserialized_body_bytes = match deserialized_body.to_provider_bytes(hermes_llm_provider)
let deserialized_body_bytes = match deserialized_body
.to_provider_bytes(hermes_llm_provider_id, ConversionMode::Compatible)
{
Ok(bytes) => bytes,
Err(e) => {
@ -529,42 +528,16 @@ impl HttpContext for StreamContext {
}
let llm_provider_str = self.llm_provider().provider_interface.to_string();
let hermes_llm_provider = Provider::from(llm_provider_str.as_str());
let _provider_id = ProviderId::from(llm_provider_str.as_str());
if self.streaming_response {
// Use the provider instance to parse streaming response
let provider_instance = self.get_provider_instance();
// TODO: Implement streaming response parsing with new provider structure
warn!(
"Streaming response parsing not yet fully implemented with new provider structure"
);
let streaming_events =
match provider_instance.parse_streaming_response(&body, &hermes_llm_provider) {
Ok(events) => events,
Err(e) => {
warn!(
"could not parse response: {}, body str: {}",
e,
String::from_utf8_lossy(&body)
);
return Action::Continue;
}
};
for event_result in streaming_events {
match event_result {
Ok(event) => {
if let Some(usage) = event.usage() {
self.response_tokens += usage.completion_tokens();
}
}
Err(e) => {
warn!("error in response event: {}", e);
continue;
}
}
}
// Compute TTFT if not already recorded
// For now, just compute TTFT and continue
if self.ttft_duration.is_none() {
// if let Some(start_time) = self.start_time {
let current_time = get_current_time().unwrap();
self.ttft_time = Some(current_time_ns());
match current_time.duration_since(self.start_time) {
@ -584,9 +557,9 @@ impl HttpContext for StreamContext {
}
} else {
debug!("non streaming response");
let provider_instance = self.get_provider_instance();
let response = match provider_instance.parse_response(&body, &hermes_llm_provider) {
Ok(de) => de,
let provider = self.get_provider();
let _response = match provider.parse_response(&body, ConversionMode::Compatible) {
Ok(response_box) => response_box,
Err(e) => {
warn!(
"could not parse response: {}, body str: {}",
@ -606,9 +579,9 @@ impl HttpContext for StreamContext {
}
};
if let Some(usage) = response.usage() {
self.response_tokens += usage.completion_tokens();
}
// TODO: Extract usage information from the response box
// For now, we'll skip this until we have a better way to handle Any types
warn!("Response token counting not yet implemented with new provider structure");
}
debug!(